From 1c224e26af86ee60eefd3656fc1fb9dc807577c8 Mon Sep 17 00:00:00 2001 From: hutuxian Date: Wed, 10 Jun 2020 15:52:25 +0800 Subject: [PATCH] support CMatchAuc (#24990) Support CMatchAucCalculator based on CMatchRankAucCalculator with a new parameter ignore_rank --- paddle/fluid/framework/data_set.cc | 1 + paddle/fluid/framework/fleet/box_wrapper.h | 29 ++++++++++++++-------- 2 files changed, 20 insertions(+), 10 deletions(-) diff --git a/paddle/fluid/framework/data_set.cc b/paddle/fluid/framework/data_set.cc index 712592357cb..1ed6569f713 100644 --- a/paddle/fluid/framework/data_set.cc +++ b/paddle/fluid/framework/data_set.cc @@ -328,6 +328,7 @@ void DatasetImpl::ReleaseMemory() { std::vector>().swap(readers_); input_records_.clear(); std::vector().swap(input_records_); + std::vector().swap(slots_shuffle_original_data_); VLOG(3) << "DatasetImpl::ReleaseMemory() end"; } diff --git a/paddle/fluid/framework/fleet/box_wrapper.h b/paddle/fluid/framework/fleet/box_wrapper.h index 27eb0d68eb4..399ee744ea9 100644 --- a/paddle/fluid/framework/fleet/box_wrapper.h +++ b/paddle/fluid/framework/fleet/box_wrapper.h @@ -637,14 +637,19 @@ class BoxWrapper { const std::string& pred_varname, int metric_phase, const std::string& cmatch_rank_group, const std::string& cmatch_rank_varname, - int bucket_size = 1000000) { + bool ignore_rank = false, int bucket_size = 1000000) { label_varname_ = label_varname; pred_varname_ = pred_varname; cmatch_rank_varname_ = cmatch_rank_varname; metric_phase_ = metric_phase; + ignore_rank_ = ignore_rank; calculator = new BasicAucCalculator(); calculator->init(bucket_size); for (auto& cmatch_rank : string::split_string(cmatch_rank_group)) { + if (ignore_rank) { // CmatchAUC + cmatch_rank_v.emplace_back(atoi(cmatch_rank.c_str()), 0); + continue; + } const std::vector& cur_cmatch_rank = string::split_string(cmatch_rank, "_"); PADDLE_ENFORCE_EQ( @@ -678,7 +683,13 @@ class BoxWrapper { for (size_t i = 0; i < batch_size; ++i) { const auto& cur_cmatch_rank = 
parse_cmatch_rank(cmatch_rank_data[i]); for (size_t j = 0; j < cmatch_rank_v.size(); ++j) { - if (cmatch_rank_v[j] == cur_cmatch_rank) { + bool is_matched = false; + if (ignore_rank_) { + is_matched = cmatch_rank_v[j].first == cur_cmatch_rank.first; + } else { + is_matched = cmatch_rank_v[j] == cur_cmatch_rank; + } + if (is_matched) { cal->add_data(pred_data[i], label_data[i]); break; } @@ -689,6 +700,7 @@ class BoxWrapper { protected: std::vector<std::pair<int, int>> cmatch_rank_v; std::string cmatch_rank_varname_; + bool ignore_rank_; }; class MaskMetricMsg : public MetricMsg { public: @@ -757,7 +769,7 @@ class BoxWrapper { const std::string& pred_varname, const std::string& cmatch_rank_varname, const std::string& mask_varname, int metric_phase, - const std::string& cmatch_rank_group, + const std::string& cmatch_rank_group, bool ignore_rank, int bucket_size = 1000000) { if (method == "AucCalculator") { metric_lists_.emplace(name, new MetricMsg(label_varname, pred_varname, @@ -768,10 +780,10 @@ class BoxWrapper { metric_phase, cmatch_rank_group, cmatch_rank_varname, bucket_size)); } else if (method == "CmatchRankAucCalculator") { - metric_lists_.emplace( - name, new CmatchRankMetricMsg(label_varname, pred_varname, - metric_phase, cmatch_rank_group, - cmatch_rank_varname, bucket_size)); + metric_lists_.emplace(name, new CmatchRankMetricMsg( + label_varname, pred_varname, metric_phase, + cmatch_rank_group, cmatch_rank_varname, + ignore_rank, bucket_size)); } else if (method == "MaskAucCalculator") { metric_lists_.emplace( name, new MaskMetricMsg(label_varname, pred_varname, metric_phase, @@ -955,9 +967,6 @@ class BoxHelper { new_input_channel->Close(); dynamic_cast(dataset_)->SetInputChannel( new_input_channel); - if (dataset_->EnablePvMerge()) { - dataset_->PreprocessInstance(); - } #endif } #ifdef PADDLE_WITH_BOX_PS -- GitLab