From 6789e1e5df58ed5ec1839dd346e8e137524ccdd8 Mon Sep 17 00:00:00 2001 From: yaopenghui Date: Tue, 27 Aug 2019 14:17:08 +0800 Subject: [PATCH] add cost monitor --- .../custom_trainer/feed/monitor/auc_monitor.cc | 13 ++----------- .../custom_trainer/feed/monitor/cost_monitor.cc | 16 +--------------- .../custom_trainer/feed/monitor/cost_monitor.h | 7 +++++-- 3 files changed, 8 insertions(+), 28 deletions(-) diff --git a/paddle/fluid/train/custom_trainer/feed/monitor/auc_monitor.cc b/paddle/fluid/train/custom_trainer/feed/monitor/auc_monitor.cc index 9f20e3b5..17eefce9 100644 --- a/paddle/fluid/train/custom_trainer/feed/monitor/auc_monitor.cc +++ b/paddle/fluid/train/custom_trainer/feed/monitor/auc_monitor.cc @@ -13,14 +13,8 @@ int AucMonitor::initialize(const YAML::Node& config, std::shared_ptr(); } set_table_size(_table_size); - _compute_interval = 3600; if (config["compute_interval"]) { - uint32_t interval = config["compute_interval"].as(); - if (interval != 3600 || interval != 86400) { - LOG(FATAL) << " AucMonitor config compute_interval just support hour: 3600 or day: 86400. "; - return -1; - } - _compute_interval = interval; + _compute_interval = config["compute_interval"].as(); } } @@ -80,8 +74,7 @@ std::string AucMonitor::format_result() { if (fabs(_predicted_ctr) > 1e-6) { copc = _actual_ctr / _predicted_ctr; } - char buf[10240]; - snprintf(buf, 10240 * sizeof(char), "%s: AUC=%.6f BUCKET_ERROR=%.6f MAE=%.6f RMSE=%.6f " + return paddle::string::format_string("%s: AUC=%.6f BUCKET_ERROR=%.6f MAE=%.6f RMSE=%.6f " "Actual CTR=%.6f Predicted CTR=%.6f COPC=%.6f INS Count=%.0f", Monitor::_name.c_str(), _auc, @@ -92,8 +85,6 @@ std::string AucMonitor::format_result() { _predicted_ctr, copc, _size); - - return std::string(buf); } void AucMonitor::add_unlocked(double pred, int label) { diff --git a/paddle/fluid/train/custom_trainer/feed/monitor/cost_monitor.cc b/paddle/fluid/train/custom_trainer/feed/monitor/cost_monitor.cc index f4b48f28..0f8b17b7 100644 --- a/paddle/fluid/train/custom_trainer/feed/monitor/cost_monitor.cc +++ b/paddle/fluid/train/custom_trainer/feed/monitor/cost_monitor.cc @@ -6,14 +6,8 @@ namespace feed { int CostMonitor::initialize(const YAML::Node& config, std::shared_ptr context_ptr) { Monitor::initialize(config, context_ptr); - _compute_interval = 3600; if (config["compute_interval"]) { - uint32_t interval = config["compute_interval"].as(); - if (interval != 3600 || interval != 86400) { - LOG(FATAL) << " AucMonitor config compute_interval just support hour: 3600 or day: 86400. "; - return -1; - } - _compute_interval = interval; + _compute_interval = config["compute_interval"].as(); } } @@ -36,14 +30,6 @@ bool CostMonitor::need_compute_result(int epoch_id, EpochAccessor* accessor) { return true; } -std::string CostMonitor::format_result() { - char buf[1024]; - snprintf(buf, 1024 * sizeof(char), "%s: Cost Time=%lu", - Monitor::_name.c_str(), - _avg_time_ms); - return std::string(buf); -} - } // namespace feed } // namespace custom_trainer } // namespace paddle diff --git a/paddle/fluid/train/custom_trainer/feed/monitor/cost_monitor.h b/paddle/fluid/train/custom_trainer/feed/monitor/cost_monitor.h index 56a0ecfb..22815355 100644 --- a/paddle/fluid/train/custom_trainer/feed/monitor/cost_monitor.h +++ b/paddle/fluid/train/custom_trainer/feed/monitor/cost_monitor.h @@ -10,7 +10,7 @@ namespace feed { // cost time profile class CostMonitor : public Monitor { public: - CostMonitor() : _total_time_ms(0), _total_cnt(0), _avg_time_ms(0) {} + CostMonitor() : _total_time_ms(0), _total_cnt(0), _avg_time_ms(0), _compute_interval(0) {} virtual ~CostMonitor() {} virtual int initialize(const YAML::Node& config, @@ -30,7 +30,10 @@ public: _avg_time_ms = _total_time_ms / _total_cnt; } //基于现有结果,输出格式化的统计信息 - virtual std::string format_result(); + virtual std::string format_result() { + return paddle::string::format_string( + "Monitor %s: Cost Time=%lu", Monitor::_name.c_str(), _avg_time_ms); + } virtual void reset() { _total_time_ms = 0; -- GitLab