提交 6789e1e5 编写于 作者: Y yaopenghui

add cost monitor

上级 3c1f7791
...@@ -13,14 +13,8 @@ int AucMonitor::initialize(const YAML::Node& config, std::shared_ptr<TrainerCont ...@@ -13,14 +13,8 @@ int AucMonitor::initialize(const YAML::Node& config, std::shared_ptr<TrainerCont
_table_size = config["table_size"].as<int>(); _table_size = config["table_size"].as<int>();
} }
set_table_size(_table_size); set_table_size(_table_size);
_compute_interval = 3600;
if (config["compute_interval"]) { if (config["compute_interval"]) {
uint32_t interval = config["compute_interval"].as<uint32_t>(); _compute_interval = config["compute_interval"].as<uint32_t>();
if (interval != 3600 || interval != 86400) {
LOG(FATAL) << " AucMonitor config compute_interval just support hour: 3600 or day: 86400. ";
return -1;
}
_compute_interval = interval;
} }
} }
...@@ -80,8 +74,7 @@ std::string AucMonitor::format_result() { ...@@ -80,8 +74,7 @@ std::string AucMonitor::format_result() {
if (fabs(_predicted_ctr) > 1e-6) { if (fabs(_predicted_ctr) > 1e-6) {
copc = _actual_ctr / _predicted_ctr; copc = _actual_ctr / _predicted_ctr;
} }
char buf[10240]; return paddle::string::format_string("%s: AUC=%.6f BUCKET_ERROR=%.6f MAE=%.6f RMSE=%.6f "
snprintf(buf, 10240 * sizeof(char), "%s: AUC=%.6f BUCKET_ERROR=%.6f MAE=%.6f RMSE=%.6f "
"Actual CTR=%.6f Predicted CTR=%.6f COPC=%.6f INS Count=%.0f", "Actual CTR=%.6f Predicted CTR=%.6f COPC=%.6f INS Count=%.0f",
Monitor::_name.c_str(), Monitor::_name.c_str(),
_auc, _auc,
...@@ -92,8 +85,6 @@ std::string AucMonitor::format_result() { ...@@ -92,8 +85,6 @@ std::string AucMonitor::format_result() {
_predicted_ctr, _predicted_ctr,
copc, copc,
_size); _size);
return std::string(buf);
} }
void AucMonitor::add_unlocked(double pred, int label) { void AucMonitor::add_unlocked(double pred, int label) {
......
...@@ -6,14 +6,8 @@ namespace feed { ...@@ -6,14 +6,8 @@ namespace feed {
int CostMonitor::initialize(const YAML::Node& config, std::shared_ptr<TrainerContext> context_ptr) { int CostMonitor::initialize(const YAML::Node& config, std::shared_ptr<TrainerContext> context_ptr) {
Monitor::initialize(config, context_ptr); Monitor::initialize(config, context_ptr);
_compute_interval = 3600;
if (config["compute_interval"]) { if (config["compute_interval"]) {
uint32_t interval = config["compute_interval"].as<uint32_t>(); _compute_interval = config["compute_interval"].as<uint32_t>();
if (interval != 3600 || interval != 86400) {
LOG(FATAL) << " AucMonitor config compute_interval just support hour: 3600 or day: 86400. ";
return -1;
}
_compute_interval = interval;
} }
} }
...@@ -36,14 +30,6 @@ bool CostMonitor::need_compute_result(int epoch_id, EpochAccessor* accessor) { ...@@ -36,14 +30,6 @@ bool CostMonitor::need_compute_result(int epoch_id, EpochAccessor* accessor) {
return true; return true;
} }
std::string CostMonitor::format_result() {
char buf[1024];
snprintf(buf, 1024 * sizeof(char), "%s: Cost Time=%lu",
Monitor::_name.c_str(),
_avg_time_ms);
return std::string(buf);
}
} // namespace feed } // namespace feed
} // namespace custom_trainer } // namespace custom_trainer
} // namespace paddle } // namespace paddle
...@@ -10,7 +10,7 @@ namespace feed { ...@@ -10,7 +10,7 @@ namespace feed {
// cost time profile // cost time profile
class CostMonitor : public Monitor { class CostMonitor : public Monitor {
public: public:
CostMonitor() : _total_time_ms(0), _total_cnt(0), _avg_time_ms(0) {} CostMonitor() : _total_time_ms(0), _total_cnt(0), _avg_time_ms(0), _compute_interval(0) {}
virtual ~CostMonitor() {} virtual ~CostMonitor() {}
virtual int initialize(const YAML::Node& config, virtual int initialize(const YAML::Node& config,
...@@ -30,7 +30,10 @@ public: ...@@ -30,7 +30,10 @@ public:
_avg_time_ms = _total_time_ms / _total_cnt; _avg_time_ms = _total_time_ms / _total_cnt;
} }
//基于现有结果,输出格式化的统计信息 //基于现有结果,输出格式化的统计信息
virtual std::string format_result(); virtual std::string format_result() {
return paddle::string::format_string(
"Monitor %s: Cost Time=%lu", Monitor::_name.c_str(), _avg_time_ms);
}
virtual void reset() { virtual void reset() {
_total_time_ms = 0; _total_time_ms = 0;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册