提交 f0d571da 编写于 作者: Y yaopenghui

add cost monitor

上级 9da0a7f4
......@@ -37,6 +37,10 @@ public:
virtual bool is_dump_all_model() {
return false;
}
// cost time millisecond
virtual uint64_t epoch_cost() const {
return 0;
}
protected:
::paddle::framework::Scope _scope;
};
......
......@@ -44,6 +44,7 @@ void AucMonitor::add_data(int epoch_id, const Executor* executor, SampleInstance
bool AucMonitor::need_compute_result(int epoch_id, EpochAccessor* accessor) {
CHECK(accessor != nullptr);
uint64_t epoch_time = accessor->epoch_timestamp(epoch_id);
CHECK(_compute_interval != 0);
if (epoch_time % _compute_interval != 0) {
return false;
}
......@@ -68,7 +69,6 @@ void AucMonitor::compute_result() {
_auc = area / (fp * tp);
_mae = Monitor::_context_ptr->environment->all_reduce_ele(_local_abserr) / (fp + tp);
_rmse = sqrt(Monitor::_context_ptr->environment->all_reduce_ele(_local_sqrerr) / (fp + tp));
_rmse = sqrt(_rmse / (fp + tp));
_actual_ctr = tp / (fp + tp);
_predicted_ctr = Monitor::_context_ptr->environment->all_reduce_ele(_local_pred) / (fp + tp);
_size = fp + tp;
......
#include "paddle/fluid/train/custom_trainer/feed/monitor/cost_monitor.h"
namespace paddle {
namespace custom_trainer {
namespace feed {
int CostMonitor::initialize(const YAML::Node& config, std::shared_ptr<TrainerContext> context_ptr) {
Monitor::initialize(config, context_ptr);
_compute_interval = 3600;
if (config["compute_interval"]) {
uint32_t interval = config["compute_interval"].as<uint32_t>();
if (interval != 3600 || interval != 86400) {
LOG(FATAL) << " AucMonitor config compute_interval just support hour: 3600 or day: 86400. ";
return -1;
}
_compute_interval = interval;
}
}
void CostMonitor::add_data(int epoch_id,
const Executor* executor,
SampleInstance* instance,
size_t num) {
CHECK(executor != nullptr);
_total_time_ms += executor->epoch_cost();
_total_cnt ++;
}
bool CostMonitor::need_compute_result(int epoch_id, EpochAccessor* accessor) {
CHECK(accessor != nullptr);
uint64_t epoch_time = accessor->epoch_timestamp(epoch_id);
CHECK(_compute_interval != 0);
if (epoch_time % _compute_interval != 0) {
return false;
}
return true;
}
std::string CostMonitor::format_result() {
char buf[1024];
snprintf(buf, 1024 * sizeof(char), "%s: Cost Time=%lu",
Monitor::_name.c_str(),
_avg_time_ms);
return std::string(buf);
}
} // namespace feed
} // namespace custom_trainer
} // namespace paddle
#pragma once
#include <string>
#include <cmath> //std::lround
#include "paddle/fluid/train/custom_trainer/feed/monitor/monitor.h"
namespace paddle {
namespace custom_trainer {
namespace feed {
// cost time profile
class CostMonitor : public Monitor {
public:
CostMonitor() : _total_time_ms(0), _total_cnt(0), _avg_time_ms(0) {}
virtual ~CostMonitor() {}
virtual int initialize(const YAML::Node& config,
std::shared_ptr<TrainerContext> context_ptr) override;
//添加一项记录,统计内容Monitor自行从Executor按需获取
virtual void add_data(int epoch_id,
const Executor* executor,
SampleInstance* instance,
size_t num);
//是否开始结果统计
virtual bool need_compute_result(int epoch_id, EpochAccessor* accessor);
//统计当前结果
virtual void compute_result() {
CHECK(_total_cnt != 0);
_avg_time_ms = _total_time_ms / _total_cnt;
}
//基于现有结果,输出格式化的统计信息
virtual std::string format_result();
virtual void reset() {
_total_time_ms = 0;
_total_cnt = 0;
_avg_time_ms = 0;
}
protected:
std::string _name;
private:
uint64_t _total_time_ms;
uint64_t _total_cnt;
uint64_t _avg_time_ms;
uint32_t _compute_interval;
};
} // namespace feed
} // namespace custom_trainer
} // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册