未验证 提交 ae3bb16d 编写于 作者: D danleifeng 提交者: GitHub

add MaskAucCalculator in paddlebox (#23157)

* add maskauc in paddlebox; test=develop
上级 6af480ca
...@@ -413,6 +413,38 @@ class BoxWrapper { ...@@ -413,6 +413,38 @@ class BoxWrapper {
std::vector<std::pair<int, int>> cmatch_rank_v; std::vector<std::pair<int, int>> cmatch_rank_v;
std::string cmatch_rank_varname_; std::string cmatch_rank_varname_;
}; };
class MaskMetricMsg : public MetricMsg {
public:
MaskMetricMsg(const std::string& label_varname,
const std::string& pred_varname, int is_join,
const std::string& mask_varname, int bucket_size = 1000000) {
label_varname_ = label_varname;
pred_varname_ = pred_varname;
mask_varname_ = mask_varname;
is_join_ = is_join;
calculator = new BasicAucCalculator();
calculator->init(bucket_size);
}
virtual ~MaskMetricMsg() {}
void add_data(const Scope* exe_scope) override {
std::vector<int64_t> label_data;
get_data<int64_t>(exe_scope, label_varname_, &label_data);
std::vector<float> pred_data;
get_data<float>(exe_scope, pred_varname_, &pred_data);
std::vector<int64_t> mask_data;
get_data<int64_t>(exe_scope, mask_varname_, &mask_data);
auto cal = GetCalculator();
auto batch_size = label_data.size();
for (size_t i = 0; i < batch_size; ++i) {
if (mask_data[i] == 1) {
cal->add_data(pred_data[i], label_data[i]);
}
}
}
protected:
std::string mask_varname_;
};
const std::vector<std::string>& GetMetricNameList() const { const std::vector<std::string>& GetMetricNameList() const {
return metric_name_list_; return metric_name_list_;
} }
...@@ -423,7 +455,8 @@ class BoxWrapper { ...@@ -423,7 +455,8 @@ class BoxWrapper {
void InitMetric(const std::string& method, const std::string& name, void InitMetric(const std::string& method, const std::string& name,
const std::string& label_varname, const std::string& label_varname,
const std::string& pred_varname, const std::string& pred_varname,
const std::string& cmatch_rank_varname, bool is_join, const std::string& cmatch_rank_varname,
const std::string& mask_varname, bool is_join,
const std::string& cmatch_rank_group, const std::string& cmatch_rank_group,
int bucket_size = 1000000) { int bucket_size = 1000000) {
if (method == "AucCalculator") { if (method == "AucCalculator") {
...@@ -439,10 +472,14 @@ class BoxWrapper { ...@@ -439,10 +472,14 @@ class BoxWrapper {
name, new CmatchRankMetricMsg(label_varname, pred_varname, name, new CmatchRankMetricMsg(label_varname, pred_varname,
is_join ? 1 : 0, cmatch_rank_group, is_join ? 1 : 0, cmatch_rank_group,
cmatch_rank_varname, bucket_size)); cmatch_rank_varname, bucket_size));
} else if (method == "MaskAucCalculator") {
metric_lists_.emplace(
name, new MaskMetricMsg(label_varname, pred_varname, is_join ? 1 : 0,
mask_varname, bucket_size));
} else { } else {
PADDLE_THROW(platform::errors::Unimplemented( PADDLE_THROW(platform::errors::Unimplemented(
"PaddleBox only support AucCalculator, MultiTaskAucCalculator and " "PaddleBox only support AucCalculator, MultiTaskAucCalculator "
"CmatchRankAucCalculator")); "CmatchRankAucCalculator and MaskAucCalculator"));
} }
metric_name_list_.emplace_back(name); metric_name_list_.emplace_back(name);
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册