未验证 提交 1c224e26 编写于 作者: H hutuxian 提交者: GitHub

support CMatchAuc (#24990)

Support CMatchAucCalculator based on CMatchRankAucCalculator with a new parameter ignore_rank
上级 28d074e9
......@@ -328,6 +328,7 @@ void DatasetImpl<T>::ReleaseMemory() {
std::vector<std::shared_ptr<paddle::framework::DataFeed>>().swap(readers_);
input_records_.clear();
std::vector<T>().swap(input_records_);
std::vector<T>().swap(slots_shuffle_original_data_);
VLOG(3) << "DatasetImpl<T>::ReleaseMemory() end";
}
......
......@@ -637,14 +637,19 @@ class BoxWrapper {
const std::string& pred_varname, int metric_phase,
const std::string& cmatch_rank_group,
const std::string& cmatch_rank_varname,
int bucket_size = 1000000) {
bool ignore_rank = false, int bucket_size = 1000000) {
label_varname_ = label_varname;
pred_varname_ = pred_varname;
cmatch_rank_varname_ = cmatch_rank_varname;
metric_phase_ = metric_phase;
ignore_rank_ = ignore_rank;
calculator = new BasicAucCalculator();
calculator->init(bucket_size);
for (auto& cmatch_rank : string::split_string(cmatch_rank_group)) {
if (ignore_rank) { // CmatchAUC
cmatch_rank_v.emplace_back(atoi(cmatch_rank.c_str()), 0);
continue;
}
const std::vector<std::string>& cur_cmatch_rank =
string::split_string(cmatch_rank, "_");
PADDLE_ENFORCE_EQ(
......@@ -678,7 +683,13 @@ class BoxWrapper {
for (size_t i = 0; i < batch_size; ++i) {
const auto& cur_cmatch_rank = parse_cmatch_rank(cmatch_rank_data[i]);
for (size_t j = 0; j < cmatch_rank_v.size(); ++j) {
if (cmatch_rank_v[j] == cur_cmatch_rank) {
bool is_matched = false;
if (ignore_rank_) {
is_matched = cmatch_rank_v[j].first == cur_cmatch_rank.first;
} else {
is_matched = cmatch_rank_v[j] == cur_cmatch_rank;
}
if (is_matched) {
cal->add_data(pred_data[i], label_data[i]);
break;
}
......@@ -689,6 +700,7 @@ class BoxWrapper {
protected:
std::vector<std::pair<int, int>> cmatch_rank_v;
std::string cmatch_rank_varname_;
bool ignore_rank_;
};
class MaskMetricMsg : public MetricMsg {
public:
......@@ -757,7 +769,7 @@ class BoxWrapper {
const std::string& pred_varname,
const std::string& cmatch_rank_varname,
const std::string& mask_varname, int metric_phase,
const std::string& cmatch_rank_group,
const std::string& cmatch_rank_group, bool ignore_rank,
int bucket_size = 1000000) {
if (method == "AucCalculator") {
metric_lists_.emplace(name, new MetricMsg(label_varname, pred_varname,
......@@ -768,10 +780,10 @@ class BoxWrapper {
metric_phase, cmatch_rank_group,
cmatch_rank_varname, bucket_size));
} else if (method == "CmatchRankAucCalculator") {
metric_lists_.emplace(
name, new CmatchRankMetricMsg(label_varname, pred_varname,
metric_phase, cmatch_rank_group,
cmatch_rank_varname, bucket_size));
metric_lists_.emplace(name, new CmatchRankMetricMsg(
label_varname, pred_varname, metric_phase,
cmatch_rank_group, cmatch_rank_varname,
ignore_rank, bucket_size));
} else if (method == "MaskAucCalculator") {
metric_lists_.emplace(
name, new MaskMetricMsg(label_varname, pred_varname, metric_phase,
......@@ -955,9 +967,6 @@ class BoxHelper {
new_input_channel->Close();
dynamic_cast<MultiSlotDataset*>(dataset_)->SetInputChannel(
new_input_channel);
if (dataset_->EnablePvMerge()) {
dataset_->PreprocessInstance();
}
#endif
}
#ifdef PADDLE_WITH_BOX_PS
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册