未验证 提交 1c224e26 编写于 作者: H hutuxian 提交者: GitHub

support CMatchAuc (#24990)

Support CMatchAuc on top of the existing CmatchRankAucCalculator by adding a new parameter, ignore_rank; when it is set, samples are matched on cmatch alone and the rank is ignored.
上级 28d074e9
...@@ -328,6 +328,7 @@ void DatasetImpl<T>::ReleaseMemory() { ...@@ -328,6 +328,7 @@ void DatasetImpl<T>::ReleaseMemory() {
std::vector<std::shared_ptr<paddle::framework::DataFeed>>().swap(readers_); std::vector<std::shared_ptr<paddle::framework::DataFeed>>().swap(readers_);
input_records_.clear(); input_records_.clear();
std::vector<T>().swap(input_records_); std::vector<T>().swap(input_records_);
std::vector<T>().swap(slots_shuffle_original_data_);
VLOG(3) << "DatasetImpl<T>::ReleaseMemory() end"; VLOG(3) << "DatasetImpl<T>::ReleaseMemory() end";
} }
......
...@@ -637,14 +637,19 @@ class BoxWrapper { ...@@ -637,14 +637,19 @@ class BoxWrapper {
const std::string& pred_varname, int metric_phase, const std::string& pred_varname, int metric_phase,
const std::string& cmatch_rank_group, const std::string& cmatch_rank_group,
const std::string& cmatch_rank_varname, const std::string& cmatch_rank_varname,
int bucket_size = 1000000) { bool ignore_rank = false, int bucket_size = 1000000) {
label_varname_ = label_varname; label_varname_ = label_varname;
pred_varname_ = pred_varname; pred_varname_ = pred_varname;
cmatch_rank_varname_ = cmatch_rank_varname; cmatch_rank_varname_ = cmatch_rank_varname;
metric_phase_ = metric_phase; metric_phase_ = metric_phase;
ignore_rank_ = ignore_rank;
calculator = new BasicAucCalculator(); calculator = new BasicAucCalculator();
calculator->init(bucket_size); calculator->init(bucket_size);
for (auto& cmatch_rank : string::split_string(cmatch_rank_group)) { for (auto& cmatch_rank : string::split_string(cmatch_rank_group)) {
if (ignore_rank) { // CmatchAUC
cmatch_rank_v.emplace_back(atoi(cmatch_rank.c_str()), 0);
continue;
}
const std::vector<std::string>& cur_cmatch_rank = const std::vector<std::string>& cur_cmatch_rank =
string::split_string(cmatch_rank, "_"); string::split_string(cmatch_rank, "_");
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
...@@ -678,7 +683,13 @@ class BoxWrapper { ...@@ -678,7 +683,13 @@ class BoxWrapper {
for (size_t i = 0; i < batch_size; ++i) { for (size_t i = 0; i < batch_size; ++i) {
const auto& cur_cmatch_rank = parse_cmatch_rank(cmatch_rank_data[i]); const auto& cur_cmatch_rank = parse_cmatch_rank(cmatch_rank_data[i]);
for (size_t j = 0; j < cmatch_rank_v.size(); ++j) { for (size_t j = 0; j < cmatch_rank_v.size(); ++j) {
if (cmatch_rank_v[j] == cur_cmatch_rank) { bool is_matched = false;
if (ignore_rank_) {
is_matched = cmatch_rank_v[j].first == cur_cmatch_rank.first;
} else {
is_matched = cmatch_rank_v[j] == cur_cmatch_rank;
}
if (is_matched) {
cal->add_data(pred_data[i], label_data[i]); cal->add_data(pred_data[i], label_data[i]);
break; break;
} }
...@@ -689,6 +700,7 @@ class BoxWrapper { ...@@ -689,6 +700,7 @@ class BoxWrapper {
protected: protected:
std::vector<std::pair<int, int>> cmatch_rank_v; std::vector<std::pair<int, int>> cmatch_rank_v;
std::string cmatch_rank_varname_; std::string cmatch_rank_varname_;
bool ignore_rank_;
}; };
class MaskMetricMsg : public MetricMsg { class MaskMetricMsg : public MetricMsg {
public: public:
...@@ -757,7 +769,7 @@ class BoxWrapper { ...@@ -757,7 +769,7 @@ class BoxWrapper {
const std::string& pred_varname, const std::string& pred_varname,
const std::string& cmatch_rank_varname, const std::string& cmatch_rank_varname,
const std::string& mask_varname, int metric_phase, const std::string& mask_varname, int metric_phase,
const std::string& cmatch_rank_group, const std::string& cmatch_rank_group, bool ignore_rank,
int bucket_size = 1000000) { int bucket_size = 1000000) {
if (method == "AucCalculator") { if (method == "AucCalculator") {
metric_lists_.emplace(name, new MetricMsg(label_varname, pred_varname, metric_lists_.emplace(name, new MetricMsg(label_varname, pred_varname,
...@@ -768,10 +780,10 @@ class BoxWrapper { ...@@ -768,10 +780,10 @@ class BoxWrapper {
metric_phase, cmatch_rank_group, metric_phase, cmatch_rank_group,
cmatch_rank_varname, bucket_size)); cmatch_rank_varname, bucket_size));
} else if (method == "CmatchRankAucCalculator") { } else if (method == "CmatchRankAucCalculator") {
metric_lists_.emplace( metric_lists_.emplace(name, new CmatchRankMetricMsg(
name, new CmatchRankMetricMsg(label_varname, pred_varname, label_varname, pred_varname, metric_phase,
metric_phase, cmatch_rank_group, cmatch_rank_group, cmatch_rank_varname,
cmatch_rank_varname, bucket_size)); ignore_rank, bucket_size));
} else if (method == "MaskAucCalculator") { } else if (method == "MaskAucCalculator") {
metric_lists_.emplace( metric_lists_.emplace(
name, new MaskMetricMsg(label_varname, pred_varname, metric_phase, name, new MaskMetricMsg(label_varname, pred_varname, metric_phase,
...@@ -955,9 +967,6 @@ class BoxHelper { ...@@ -955,9 +967,6 @@ class BoxHelper {
new_input_channel->Close(); new_input_channel->Close();
dynamic_cast<MultiSlotDataset*>(dataset_)->SetInputChannel( dynamic_cast<MultiSlotDataset*>(dataset_)->SetInputChannel(
new_input_channel); new_input_channel);
if (dataset_->EnablePvMerge()) {
dataset_->PreprocessInstance();
}
#endif #endif
} }
#ifdef PADDLE_WITH_BOX_PS #ifdef PADDLE_WITH_BOX_PS
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册