diff --git a/paddle/fluid/framework/fleet/box_wrapper.h b/paddle/fluid/framework/fleet/box_wrapper.h index 8749011e61e1f29b483a82eafd9fcf8db84559d1..b2a2718bbc18665e8e5c63d93d2a9cf0d1a9164b 100644 --- a/paddle/fluid/framework/fleet/box_wrapper.h +++ b/paddle/fluid/framework/fleet/box_wrapper.h @@ -413,6 +413,38 @@ class BoxWrapper { std::vector> cmatch_rank_v; std::string cmatch_rank_varname_; }; + class MaskMetricMsg : public MetricMsg { + public: + MaskMetricMsg(const std::string& label_varname, + const std::string& pred_varname, int is_join, + const std::string& mask_varname, int bucket_size = 1000000) { + label_varname_ = label_varname; + pred_varname_ = pred_varname; + mask_varname_ = mask_varname; + is_join_ = is_join; + calculator = new BasicAucCalculator(); + calculator->init(bucket_size); + } + virtual ~MaskMetricMsg() {} + void add_data(const Scope* exe_scope) override { + std::vector label_data; + get_data(exe_scope, label_varname_, &label_data); + std::vector pred_data; + get_data(exe_scope, pred_varname_, &pred_data); + std::vector mask_data; + get_data(exe_scope, mask_varname_, &mask_data); + auto cal = GetCalculator(); + auto batch_size = label_data.size(); + for (size_t i = 0; i < batch_size; ++i) { + if (mask_data[i] == 1) { + cal->add_data(pred_data[i], label_data[i]); + } + } + } + + protected: + std::string mask_varname_; + }; const std::vector& GetMetricNameList() const { return metric_name_list_; } @@ -423,7 +455,8 @@ class BoxWrapper { void InitMetric(const std::string& method, const std::string& name, const std::string& label_varname, const std::string& pred_varname, - const std::string& cmatch_rank_varname, bool is_join, + const std::string& cmatch_rank_varname, + const std::string& mask_varname, bool is_join, const std::string& cmatch_rank_group, int bucket_size = 1000000) { if (method == "AucCalculator") { @@ -439,10 +472,14 @@ class BoxWrapper { name, new CmatchRankMetricMsg(label_varname, pred_varname, is_join ? 1 : 0, cmatch_rank_group, cmatch_rank_varname, bucket_size)); + } else if (method == "MaskAucCalculator") { + metric_lists_.emplace( + name, new MaskMetricMsg(label_varname, pred_varname, is_join ? 1 : 0, + mask_varname, bucket_size)); } else { PADDLE_THROW(platform::errors::Unimplemented( - "PaddleBox only support AucCalculator, MultiTaskAucCalculator and " - "CmatchRankAucCalculator")); + "PaddleBox only support AucCalculator, MultiTaskAucCalculator " + "CmatchRankAucCalculator and MaskAucCalculator")); } metric_name_list_.emplace_back(name); }