提交 f7d671a6 编写于 作者: Y yaopenghui

add auc monitor

上级 46be5864
...@@ -5,6 +5,58 @@ namespace paddle { ...@@ -5,6 +5,58 @@ namespace paddle {
namespace custom_trainer { namespace custom_trainer {
namespace feed { namespace feed {
template<class T>
struct mpi_type_trait {
};
template<>
struct mpi_type_trait<double> {
static MPI_Datatype type() {
return MPI_DOUBLE;
}
};
template<>
struct mpi_type_trait<float> {
static MPI_Datatype type() {
return MPI_FLOAT;
}
};
template<>
struct mpi_type_trait<int32_t> {
static MPI_Datatype type() {
return MPI_INT;
}
};
template<>
struct mpi_type_trait<uint32_t> {
static MPI_Datatype type() {
return MPI_UNSIGNED;
}
};
template<>
struct mpi_type_trait<int64_t> {
static MPI_Datatype type() {
return MPI_LONG_LONG;
}
};
template<>
struct mpi_type_trait<uint64_t> {
static MPI_Datatype type() {
return MPI_UNSIGNED_LONG_LONG;
}
};
template<>
struct mpi_type_trait<long long> {
static MPI_Datatype type() {
return MPI_LONG_LONG;
}
};
template<>
struct mpi_type_trait<unsigned long long> {
static MPI_Datatype type() {
return MPI_UNSIGNED_LONG_LONG;
}
};
RuntimeEnvironment::RuntimeEnvironment() {} RuntimeEnvironment::RuntimeEnvironment() {}
RuntimeEnvironment::~RuntimeEnvironment() {} RuntimeEnvironment::~RuntimeEnvironment() {}
bool RuntimeEnvironment::is_master_node(EnvironmentRole role) { bool RuntimeEnvironment::is_master_node(EnvironmentRole role) {
...@@ -79,6 +131,15 @@ public: ...@@ -79,6 +131,15 @@ public:
MPI_Bcast(ar.Buffer(), len, MPI_BYTE, root_id, node_info.mpi_comm); MPI_Bcast(ar.Buffer(), len, MPI_BYTE, root_id, node_info.mpi_comm);
} }
virtual double all_reduce_ele(double x) {
double tot;
MPI_Allreduce(&x, &tot, 1, mpi_type_trait<double>::type(), MPI_SUM, MPI_COMM_WORLD);
return tot;
}
virtual void all_reduce_arr(double* x, int n) {
MPI_Allreduce(MPI_IN_PLACE, x, n, mpi_type_trait<double>::type(), MPI_SUM, MPI_COMM_WORLD);
}
protected: protected:
virtual void print_log(EnvironmentRole role, EnvironmentLogType type, virtual void print_log(EnvironmentRole role, EnvironmentLogType type,
EnvironmentLogLevel level, const std::string& log_str) { EnvironmentLogLevel level, const std::string& log_str) {
...@@ -123,6 +184,12 @@ public: ...@@ -123,6 +184,12 @@ public:
virtual void bcast(paddle::framework::BinaryArchive& ar, int root_id, EnvironmentRole role) { virtual void bcast(paddle::framework::BinaryArchive& ar, int root_id, EnvironmentRole role) {
return; return;
} }
virtual double all_reduce_ele(double x) {
return x;
}
virtual void all_reduce_arr(double* x, int n) {
return;
}
protected: protected:
virtual void print_log(EnvironmentRole role, EnvironmentLogType type, virtual void print_log(EnvironmentRole role, EnvironmentLogType type,
EnvironmentLogLevel level, const std::string& log_str) { EnvironmentLogLevel level, const std::string& log_str) {
......
...@@ -67,6 +67,10 @@ public: ...@@ -67,6 +67,10 @@ public:
virtual void barrier(EnvironmentRole role) = 0; virtual void barrier(EnvironmentRole role) = 0;
//bcast 广播 //bcast 广播
virtual void bcast(paddle::framework::BinaryArchive& ar, int root_id, EnvironmentRole role) = 0; virtual void bcast(paddle::framework::BinaryArchive& ar, int root_id, EnvironmentRole role) = 0;
//all_reduce sum element 规约元素
virtual double all_reduce_ele(double x) = 0;
//all_reduce sum array 规约数组
virtual void all_reduce_arr(double* x, int n) = 0;
//接口只允许在主线程调用 End //接口只允许在主线程调用 End
protected: protected:
virtual void print_log(EnvironmentRole role, EnvironmentLogType type, virtual void print_log(EnvironmentRole role, EnvironmentLogType type,
......
...@@ -51,6 +51,7 @@ private: ...@@ -51,6 +51,7 @@ private:
struct SampleInstance { struct SampleInstance {
std::string id; std::string id;
std::vector<float> predicts;
std::vector<float> labels; std::vector<float> labels;
std::vector<FeatureItem> features; std::vector<FeatureItem> features;
std::vector<float> embedx; std::vector<float> embedx;
......
#include "paddle/fluid/train/custom_trainer/feed/monitor/auc_monitor.h"
namespace paddle {
namespace custom_trainer {
namespace feed {
int AucMonitor::initialize(const YAML::Node& config, std::shared_ptr<TrainerContext> context_ptr) {
Monitor::initialize(config, context_ptr);
_target_name = config["target"].as<std::string>();
_label_name = config["label"].as<std::string>();
_table_size = 1000000;
if (config["table_size"]) {
_table_size = config["table_size"].as<int>();
}
set_table_size(_table_size);
_compute_interval = 3600;
if (config["compute_interval"]) {
uint32_t interval = config["compute_interval"].as<uint32_t>();
if (interval != 3600 || interval != 86400) {
LOG(FATAL) << " AucMonitor config compute_interval just support hour: 3600 or day: 86400. ";
return -1;
}
_compute_interval = interval;
}
}
void AucMonitor::add_data(int epoch_id, const Executor* executor, SampleInstance* instance, size_t num) {
if (executor == nullptr
|| instance == nullptr
|| instance->predicts.empty()
|| instance->labels.empty()
|| num <= 0
|| instance->predicts.size() < num
|| instance->labels.size() < num) {
LOG(FATAL) << "AucMonitor add predict data is invalid, predicts or labels is empty, num[" << num << "]";
return;
}
std::lock_guard<std::mutex> lock(_mutex);
for (int i = 0; i < num; ++i) {
add_unlocked(instance->predicts[i], std::lround(instance->labels[i]));
}
}
bool AucMonitor::need_compute_result(int epoch_id, EpochAccessor* accessor) {
CHECK(accessor != nullptr);
uint64_t epoch_time = accessor->epoch_timestamp(epoch_id);
if (epoch_time % _compute_interval != 0) {
return false;
}
return true;
}
void AucMonitor::compute_result() {
double* table[2] = {&_table[0][0], &_table[1][0]};
for (int i = 0; i < 2; i++) {
Monitor::_context_ptr->environment->all_reduce_arr(table[i], _table_size);
}
double area = 0;
double fp = 0;
double tp = 0;
for (int i = _table_size - 1; i >= 0; i--) {
double newfp = fp + table[0][i];
double newtp = tp + table[1][i];
area += (newfp - fp) * (tp + newtp) / 2;
fp = newfp;
tp = newtp;
}
_auc = area / (fp * tp);
_mae = Monitor::_context_ptr->environment->all_reduce_ele(_local_abserr) / (fp + tp);
_rmse = sqrt(Monitor::_context_ptr->environment->all_reduce_ele(_local_sqrerr) / (fp + tp));
_rmse = sqrt(_rmse / (fp + tp));
_actual_ctr = tp / (fp + tp);
_predicted_ctr = Monitor::_context_ptr->environment->all_reduce_ele(_local_pred) / (fp + tp);
_size = fp + tp;
calculate_bucket_error();
}
std::string AucMonitor::format_result() {
double copc = 0.0;
if (fabs(_predicted_ctr) > 1e-6) {
copc = _actual_ctr / _predicted_ctr;
}
char buf[10240];
snprintf(buf, 10240 * sizeof(char), "%s: AUC=%.6f BUCKET_ERROR=%.6f MAE=%.6f RMSE=%.6f "
"Actual CTR=%.6f Predicted CTR=%.6f COPC=%.6f INS Count=%.0f",
Monitor::_name.c_str(),
_auc,
_bucket_error,
_mae,
_rmse,
_actual_ctr,
_predicted_ctr,
copc,
_size);
return std::string(buf);
}
void AucMonitor::add_unlocked(double pred, int label) {
CHECK(pred >= 0 && pred <= 1) << "pred[" << pred << "] outside of [0,1]";
CHECK(label == 0 || label == 1) << "label[" << label << "] invalid";
_table[label][std::min(int(pred * _table_size), _table_size - 1)]++;
_local_abserr += fabs(pred - label);
_local_sqrerr += (pred - label) * (pred - label);
_local_pred += pred;
}
void AucMonitor::calculate_bucket_error() {
double last_ctr = -1;
double impression_sum = 0;
double ctr_sum = 0.0;
double click_sum = 0.0;
double error_sum = 0.0;
double error_count = 0;
double* table[2] = {&_table[0][0], &_table[1][0]};
for (int i = 0; i < _table_size; i++) {
double click = table[1][i];
double show = table[0][i] + table[1][i];
double ctr = (double)i / _table_size;
if (fabs(ctr - last_ctr) > kMaxSpan) {
last_ctr = ctr;
impression_sum = 0.0;
ctr_sum = 0.0;
click_sum = 0.0;
}
impression_sum += show;
ctr_sum += ctr * show;
click_sum += click;
double adjust_ctr = ctr_sum / impression_sum;
double relative_error = sqrt((1 - adjust_ctr) / (adjust_ctr * impression_sum));
if (relative_error < kRelativeErrorBound) {
double actual_ctr = click_sum / impression_sum;
double relative_ctr_error = fabs(actual_ctr / adjust_ctr - 1);
error_sum += relative_ctr_error * impression_sum;
error_count += impression_sum;
last_ctr = -1;
}
}
_bucket_error = error_count > 0 ? error_sum / error_count : 0.0;
}
void AucMonitor::set_table_size(int table_size) {
CHECK(table_size >= 1);
_table_size = table_size;
for (int i = 0; i < 2; i++) {
_table[i] = std::vector<double>();
}
reset();
}
void AucMonitor::reset() {
for (int i = 0; i < 2; i++) {
_table[i].assign(_table_size, 0.0);
}
_local_abserr = 0;
_local_sqrerr = 0;
_local_pred = 0;
}
} // namespace feed
} // namespace custom_trainer
} // namespace paddle
#pragma once #pragma once
#include <string> #include <string>
#include <cmath> //std::lround
#include "paddle/fluid/train/custom_trainer/feed/monitor/monitor.h" #include "paddle/fluid/train/custom_trainer/feed/monitor/monitor.h"
namespace paddle { namespace paddle {
...@@ -14,14 +15,13 @@ public: ...@@ -14,14 +15,13 @@ public:
virtual ~AucMonitor() {} virtual ~AucMonitor() {}
virtual int initialize(const YAML::Node& config, virtual int initialize(const YAML::Node& config,
std::shared_ptr<TrainerContext> context_ptr) { std::shared_ptr<TrainerContext> context_ptr) override;
Monitor::initialize(config, context_ptr);
//一些额外配置 对于AUC主要是target && label 信息
return 0;
}
//添加一项记录,统计内容Monitor自行从Executor按需获取 //添加一项记录,统计内容Monitor自行从Executor按需获取
virtual void add_data(int epoch_id, const Executor* executor); virtual void add_data(int epoch_id,
const Executor* executor,
SampleInstance* instance,
size_t num);
//是否开始结果统计 //是否开始结果统计
virtual bool need_compute_result(int epoch_id, EpochAccessor* accessor); virtual bool need_compute_result(int epoch_id, EpochAccessor* accessor);
...@@ -31,6 +31,31 @@ public: ...@@ -31,6 +31,31 @@ public:
virtual std::string format_result(); virtual std::string format_result();
virtual void reset(); virtual void reset();
protected:
std::string _label_name;
std::string _target_name;
std::string _name;
std::string _output_var;
std::mutex _mutex;
double _local_abserr, _local_sqrerr, _local_pred;
double _auc;
double _mae;
double _rmse;
double _actual_ctr, _predicted_ctr;
double _size;
double _bucket_error;
int _table_size;
void add_unlocked(double pred, int label);
private:
void calculate_bucket_error();
void set_table_size(int table_size);
uint32_t _compute_interval;
std::vector<double> _table[2];
static constexpr double kRelativeErrorBound = 0.05;
static constexpr double kMaxSpan = 0.01;
}; };
} // namespace feed } // namespace feed
......
#pragma once #pragma once
#include <string> #include <string>
#include "paddle/fluid/train/custom_trainer/feed/accessor/epoch_accessor.h"
#include "paddle/fluid/train/custom_trainer/feed/common/registerer.h" #include "paddle/fluid/train/custom_trainer/feed/common/registerer.h"
#include "paddle/fluid/train/custom_trainer/feed/trainer_context.h" #include "paddle/fluid/train/custom_trainer/feed/trainer_context.h"
#include "paddle/fluid/train/custom_trainer/feed/executor/executor.h" #include "paddle/fluid/train/custom_trainer/feed/executor/executor.h"
#include "paddle/fluid/train/custom_trainer/feed/dataset/data_reader.h"
namespace paddle { namespace paddle {
namespace custom_trainer { namespace custom_trainer {
...@@ -15,12 +17,13 @@ public: ...@@ -15,12 +17,13 @@ public:
virtual int initialize(const YAML::Node& config, virtual int initialize(const YAML::Node& config,
std::shared_ptr<TrainerContext> context_ptr) { std::shared_ptr<TrainerContext> context_ptr) {
_name = conf["name"].as<std::string>(); _name = config["name"].as<std::string>();
_context_ptr = context_ptr;
return 0; return 0;
} }
//添加一项记录,统计内容Monitor自行从Executor按需获取 //添加一项记录,统计内容Monitor自行从Executor按需获取
virtual void add_data(int epoch_id, const Executor* executor) = 0; virtual void add_data(int epoch_id, const Executor* executor, SampleInstance* instance, size_t num) = 0;
//是否对于当前epoch_id进行结果统计 //是否对于当前epoch_id进行结果统计
virtual bool need_compute_result(int epoch_id, EpochAccessor* accessor) = 0; virtual bool need_compute_result(int epoch_id, EpochAccessor* accessor) = 0;
...@@ -37,6 +40,7 @@ public: ...@@ -37,6 +40,7 @@ public:
protected: protected:
std::string _name; std::string _name;
std::shared_ptr<TrainerContext> _context_ptr;
}; };
REGISTER_REGISTERER(Monitor); REGISTER_REGISTERER(Monitor);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册