提交 b380a55f 编写于 作者: X xiexionghang

fix code style

......@@ -23,6 +23,10 @@ public:
// 执行训练
virtual int run(::paddle::framework::Scope* scope) = 0;
// cost time millisecond
virtual uint64_t epoch_cost() const {
return 0;
}
};
REGIST_REGISTERER(Executor);
......
......@@ -13,11 +13,8 @@ int AucMonitor::initialize(const YAML::Node& config, std::shared_ptr<TrainerCont
_table_size = config["table_size"].as<int>();
}
set_table_size(_table_size);
_compute_interval = 3600;
if (config["compute_interval"]) {
_compute_interval = config["compute_interval"].as<uint32_t>();
CHECK(_compute_interval % 60 == 0);
}
_compute_interval = config["compute_interval"].as<uint32_t>();
CHECK(_compute_interval % 60 == 0);
return 0;
}
......@@ -36,7 +33,6 @@ bool AucMonitor::need_compute_result(int epoch_id) {
uint64_t epoch_time = _epoch_accessor->epoch_timestamp(epoch_id);
return epoch_time % _compute_interval == 0;
}
void AucMonitor::compute_result() {
auto* environment = Monitor::_context_ptr->environment.get();
double* table[2] = {&_table[0][0], &_table[1][0]};
......@@ -59,7 +55,6 @@ void AucMonitor::compute_result() {
ReduceOperator::SUM, EnvironmentRole::WORKER) / (fp + tp);
_rmse = sqrt(environment->all_reduce(_local_sqrerr,
ReduceOperator::SUM, EnvironmentRole::WORKER) / (fp + tp));
_rmse = sqrt(_rmse / (fp + tp));
_actual_ctr = tp / (fp + tp);
_predicted_ctr = environment->all_reduce(_local_pred,
ReduceOperator::SUM, EnvironmentRole::WORKER) / (fp + tp);
......@@ -72,8 +67,7 @@ std::string AucMonitor::format_result() {
if (fabs(_predicted_ctr) > 1e-6) {
copc = _actual_ctr / _predicted_ctr;
}
char buf[10240];
snprintf(buf, 10240 * sizeof(char), "AUC=%.6f BUCKET_ERROR=%.6f MAE=%.6f RMSE=%.6f "
return paddle::string::format_string("AUC=%.6f BUCKET_ERROR=%.6f MAE=%.6f RMSE=%.6f "
"Actual CTR=%.6f Predicted CTR=%.6f COPC=%.6f INS Count=%.0f",
_auc,
_bucket_error,
......@@ -83,8 +77,6 @@ std::string AucMonitor::format_result() {
_predicted_ctr,
copc,
_size);
return std::string(buf);
}
void AucMonitor::add_unlocked(double pred, int label) {
......
#include "paddle/fluid/train/custom_trainer/feed/monitor/cost_monitor.h"
namespace paddle {
namespace custom_trainer {
namespace feed {
int CostMonitor::initialize(const YAML::Node& config, std::shared_ptr<TrainerContext> context_ptr) {
Monitor::initialize(config, context_ptr);
if (config["compute_interval"]) {
_compute_interval = config["compute_interval"].as<uint32_t>();
}
}
void CostMonitor::add_data(int epoch_id,
const MultiThreadExecutor* executor,
SampleInstance* samples, size_t num) {
CHECK(executor != nullptr);
//TODO use paddle time
_total_time_ms += 1;
_total_cnt ++;
}
bool CostMonitor::need_compute_result(int epoch_id) {
uint64_t epoch_time = _epoch_accessor->epoch_timestamp(epoch_id);
return epoch_time % _compute_interval == 0;
}
} // namespace feed
} // namespace custom_trainer
} // namespace paddle
#pragma once
#include <string>
#include <cmath> //std::lround
#include "paddle/fluid/train/custom_trainer/feed/monitor/monitor.h"
namespace paddle {
namespace custom_trainer {
namespace feed {
// cost time profile
class CostMonitor : public Monitor {
public:
CostMonitor() : _total_time_ms(0), _total_cnt(0), _avg_time_ms(0), _compute_interval(0) {}
virtual ~CostMonitor() {}
virtual int initialize(const YAML::Node& config,
std::shared_ptr<TrainerContext> context_ptr) override;
//添加一项记录,统计内容Monitor自行从Executor按需获取
virtual void add_data(int epoch_id,
const MultiThreadExecutor* executor,
SampleInstance* samples,
size_t num);
//是否开始结果统计
virtual bool need_compute_result(int epoch_id);
//统计当前结果
virtual void compute_result() {
CHECK(_total_cnt != 0);
_avg_time_ms = _total_time_ms / _total_cnt;
}
//基于现有结果,输出格式化的统计信息
virtual std::string format_result() {
return paddle::string::format_string(
"Monitor %s: Cost Time=%lu", Monitor::_name.c_str(), _avg_time_ms);
}
virtual void reset() {
_total_time_ms = 0;
_total_cnt = 0;
_avg_time_ms = 0;
}
protected:
std::string _name;
private:
uint64_t _total_time_ms;
uint64_t _total_cnt;
uint64_t _avg_time_ms;
uint32_t _compute_interval;
};
} // namespace feed
} // namespace custom_trainer
} // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册