提交 b380a55f 编写于 作者: X xiexionghang

fix code style

...@@ -23,6 +23,10 @@ public: ...@@ -23,6 +23,10 @@ public:
// 执行训练 // 执行训练
virtual int run(::paddle::framework::Scope* scope) = 0; virtual int run(::paddle::framework::Scope* scope) = 0;
// cost time millisecond
virtual uint64_t epoch_cost() const {
return 0;
}
}; };
REGIST_REGISTERER(Executor); REGIST_REGISTERER(Executor);
......
...@@ -13,11 +13,8 @@ int AucMonitor::initialize(const YAML::Node& config, std::shared_ptr<TrainerCont ...@@ -13,11 +13,8 @@ int AucMonitor::initialize(const YAML::Node& config, std::shared_ptr<TrainerCont
_table_size = config["table_size"].as<int>(); _table_size = config["table_size"].as<int>();
} }
set_table_size(_table_size); set_table_size(_table_size);
_compute_interval = 3600;
if (config["compute_interval"]) {
_compute_interval = config["compute_interval"].as<uint32_t>(); _compute_interval = config["compute_interval"].as<uint32_t>();
CHECK(_compute_interval % 60 == 0); CHECK(_compute_interval % 60 == 0);
}
return 0; return 0;
} }
...@@ -36,7 +33,6 @@ bool AucMonitor::need_compute_result(int epoch_id) { ...@@ -36,7 +33,6 @@ bool AucMonitor::need_compute_result(int epoch_id) {
uint64_t epoch_time = _epoch_accessor->epoch_timestamp(epoch_id); uint64_t epoch_time = _epoch_accessor->epoch_timestamp(epoch_id);
return epoch_time % _compute_interval == 0; return epoch_time % _compute_interval == 0;
} }
void AucMonitor::compute_result() { void AucMonitor::compute_result() {
auto* environment = Monitor::_context_ptr->environment.get(); auto* environment = Monitor::_context_ptr->environment.get();
double* table[2] = {&_table[0][0], &_table[1][0]}; double* table[2] = {&_table[0][0], &_table[1][0]};
...@@ -59,7 +55,6 @@ void AucMonitor::compute_result() { ...@@ -59,7 +55,6 @@ void AucMonitor::compute_result() {
ReduceOperator::SUM, EnvironmentRole::WORKER) / (fp + tp); ReduceOperator::SUM, EnvironmentRole::WORKER) / (fp + tp);
_rmse = sqrt(environment->all_reduce(_local_sqrerr, _rmse = sqrt(environment->all_reduce(_local_sqrerr,
ReduceOperator::SUM, EnvironmentRole::WORKER) / (fp + tp)); ReduceOperator::SUM, EnvironmentRole::WORKER) / (fp + tp));
_rmse = sqrt(_rmse / (fp + tp));
_actual_ctr = tp / (fp + tp); _actual_ctr = tp / (fp + tp);
_predicted_ctr = environment->all_reduce(_local_pred, _predicted_ctr = environment->all_reduce(_local_pred,
ReduceOperator::SUM, EnvironmentRole::WORKER) / (fp + tp); ReduceOperator::SUM, EnvironmentRole::WORKER) / (fp + tp);
...@@ -72,8 +67,7 @@ std::string AucMonitor::format_result() { ...@@ -72,8 +67,7 @@ std::string AucMonitor::format_result() {
if (fabs(_predicted_ctr) > 1e-6) { if (fabs(_predicted_ctr) > 1e-6) {
copc = _actual_ctr / _predicted_ctr; copc = _actual_ctr / _predicted_ctr;
} }
char buf[10240]; return paddle::string::format_string("AUC=%.6f BUCKET_ERROR=%.6f MAE=%.6f RMSE=%.6f "
snprintf(buf, 10240 * sizeof(char), "AUC=%.6f BUCKET_ERROR=%.6f MAE=%.6f RMSE=%.6f "
"Actual CTR=%.6f Predicted CTR=%.6f COPC=%.6f INS Count=%.0f", "Actual CTR=%.6f Predicted CTR=%.6f COPC=%.6f INS Count=%.0f",
_auc, _auc,
_bucket_error, _bucket_error,
...@@ -83,8 +77,6 @@ std::string AucMonitor::format_result() { ...@@ -83,8 +77,6 @@ std::string AucMonitor::format_result() {
_predicted_ctr, _predicted_ctr,
copc, copc,
_size); _size);
return std::string(buf);
} }
void AucMonitor::add_unlocked(double pred, int label) { void AucMonitor::add_unlocked(double pred, int label) {
......
#include "paddle/fluid/train/custom_trainer/feed/monitor/cost_monitor.h"
namespace paddle {
namespace custom_trainer {
namespace feed {
int CostMonitor::initialize(const YAML::Node& config, std::shared_ptr<TrainerContext> context_ptr) {
Monitor::initialize(config, context_ptr);
if (config["compute_interval"]) {
_compute_interval = config["compute_interval"].as<uint32_t>();
}
}
void CostMonitor::add_data(int epoch_id,
const MultiThreadExecutor* executor,
SampleInstance* samples, size_t num) {
CHECK(executor != nullptr);
//TODO use paddle time
_total_time_ms += 1;
_total_cnt ++;
}
bool CostMonitor::need_compute_result(int epoch_id) {
uint64_t epoch_time = _epoch_accessor->epoch_timestamp(epoch_id);
return epoch_time % _compute_interval == 0;
}
} // namespace feed
} // namespace custom_trainer
} // namespace paddle
#pragma once
#include <string>
#include <cmath> //std::lround
#include "paddle/fluid/train/custom_trainer/feed/monitor/monitor.h"
namespace paddle {
namespace custom_trainer {
namespace feed {
// cost time profile
class CostMonitor : public Monitor {
public:
CostMonitor() : _total_time_ms(0), _total_cnt(0), _avg_time_ms(0), _compute_interval(0) {}
virtual ~CostMonitor() {}
virtual int initialize(const YAML::Node& config,
std::shared_ptr<TrainerContext> context_ptr) override;
//添加一项记录,统计内容Monitor自行从Executor按需获取
virtual void add_data(int epoch_id,
const MultiThreadExecutor* executor,
SampleInstance* samples,
size_t num);
//是否开始结果统计
virtual bool need_compute_result(int epoch_id);
//统计当前结果
virtual void compute_result() {
CHECK(_total_cnt != 0);
_avg_time_ms = _total_time_ms / _total_cnt;
}
//基于现有结果,输出格式化的统计信息
virtual std::string format_result() {
return paddle::string::format_string(
"Monitor %s: Cost Time=%lu", Monitor::_name.c_str(), _avg_time_ms);
}
virtual void reset() {
_total_time_ms = 0;
_total_cnt = 0;
_avg_time_ms = 0;
}
protected:
std::string _name;
private:
uint64_t _total_time_ms;
uint64_t _total_cnt;
uint64_t _avg_time_ms;
uint32_t _compute_interval;
};
} // namespace feed
} // namespace custom_trainer
} // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册