diff --git a/paddle/fluid/train/custom_trainer/feed/dataset/data_reader.cc b/paddle/fluid/train/custom_trainer/feed/dataset/data_reader.cc index ba0a74f1b4c654870e99b9f6bdc43e86964206f3..a84e9467a25c92e119f12f0b52bcfe1f73aa8917 100755 --- a/paddle/fluid/train/custom_trainer/feed/dataset/data_reader.cc +++ b/paddle/fluid/train/custom_trainer/feed/dataset/data_reader.cc @@ -87,7 +87,7 @@ public: std::string tags; std::map> vec_feas; int sample_type; - std::map> auc_category_info_map; //为细维度计算auc准备的数据 + std::map> auc_category_info_map; //为细维度计算auc准备的数据 std::vector hot_feas, cold_feas; //冷(int32_t)热(uint64_t)feasign void clear() { @@ -149,7 +149,7 @@ public: return cursor; } - void parse_feas(char* buffer) const { + void serialize_to_compress_feas(char* buffer) const { if (buffer == nullptr) { return ; } @@ -293,16 +293,21 @@ public: std::string& single_category_str = all_category_vec[i]; std::vector str_vec = paddle::string::split_string(single_category_str, "="); CHECK(str_vec.size() == 2); - /*std::string category_name = str_vec[0]; - std::vector category_info_vec = paddle::string::split_string(str_vec[1], ","); + std::string category_name = str_vec[0]; + std::vector category_info_vec = paddle::string::split_string(str_vec[1], ","); CHECK(category_info_vec.size() > 0); - CHECK(rec.auc_category_info_map.insert({category_name, category_info_vec}).second);*/ + CHECK(rec.auc_category_info_map.insert({category_name, category_info_vec}).second); } } else { uint64_t sign = 0; int slot = -1; - CHECK((sign = (uint64_t) strtoull(str, &cursor, 10), cursor != str)); + sign = (uint64_t)strtoull(str, &cursor, 10); + if (cursor == str) { //FIXME abacus没有这种情况 + str++; + continue; + } + //CHECK((sign = (uint64_t)strtoull(str, &cursor, 10), cursor != str)); str = cursor; CHECK(*str++ == ':'); CHECK(!isspace(*str)); @@ -320,17 +325,12 @@ public: } paddle::framework::BinaryArchive bar; - bar << rec.show; - bar << rec.clk; - bar << rec.tags; - bar << rec.vec_feas; - bar << rec.sample_type; - bar << rec.auc_category_info_map; - uint32_t feas_len = rec.calc_compress_feas_lens(); + bar << rec.show << rec.clk << rec.tags << rec.vec_feas << rec.sample_type << rec.auc_category_info_map; + uint32_t feas_len = rec.calc_compress_feas_lens(); //事先计算好压缩后feasign的空间 bar << feas_len; bar.Resize(bar.Length() + feas_len); - rec.parse_feas(bar.Cursor()); - data.data.assign(bar.Buffer(), bar.Length()); + rec.serialize_to_compress_feas(bar.Finish() - feas_len); //直接在archive内部buffer进行压缩,避免不必要的拷贝 + data.data.assign(bar.Buffer(), bar.Length()); //TODO 这一步拷贝是否也能避免 return 0; } @@ -355,7 +355,6 @@ public: paddle::framework::BinaryArchive bar; bar.SetReadBuffer(const_cast(&data.data[0]), data.data.size(), nullptr); - bar >> show; bar >> clk; bar >> tags; @@ -363,13 +362,14 @@ public: bar >> sample_type; bar >> auc_category_info_map; bar >> feas_len; - parse_feas_to_ins(bar.Cursor(), feas_len, instance.features); + CHECK((bar.Finish() - bar.Cursor()) == feas_len); + deserialize_feas_to_ins(bar.Cursor(), feas_len, instance.features); return 0; } private: - void parse_feas_to_ins(char* buffer, uint32_t len, std::vector& ins) const { + void deserialize_feas_to_ins(char* buffer, uint32_t len, std::vector& ins) const { if (buffer == nullptr) { return ; } diff --git a/paddle/fluid/train/custom_trainer/feed/dataset/dataset_container.cc b/paddle/fluid/train/custom_trainer/feed/dataset/dataset_container.cc old mode 100644 new mode 100755 diff --git a/paddle/fluid/train/custom_trainer/feed/process/learner_process.cc b/paddle/fluid/train/custom_trainer/feed/process/learner_process.cc old mode 100644 new mode 100755 index cf9ac43b96b50ea4dfa0859d434f6a46992dadf5..71568eade83ca5bad02d72cc09aacc8a0ddd1c4e --- a/paddle/fluid/train/custom_trainer/feed/process/learner_process.cc +++ b/paddle/fluid/train/custom_trainer/feed/process/learner_process.cc @@ -76,7 +76,7 @@ int LearnerProcess::run() { uint64_t epoch_id = epoch_accessor->current_epoch_id(); environment->log(EnvironmentRole::WORKER, EnvironmentLogType::MASTER_LOG, EnvironmentLogLevel::NOTICE, - "Resume training with epoch_id:%d label:%s", epoch_id, _context_ptr->epoch_accessor->text(epoch_id).c_str()); + "Resume train with epoch_id:%d label:%s", epoch_id, _context_ptr->epoch_accessor->text(epoch_id).c_str()); //判断是否先dump出base wait_save_model(epoch_id, ModelSaveWay::ModelSaveInferenceBase);