未验证 提交 7460a891 编写于 作者: F Fan Zhang 提交者: GitHub

[PSLIB] Add Metrics Module, Support User-defined Add Metric (#38230)

* 12.3 first add metrics module

* add Mask/MultiTask

* add WuAUC

* [PSLIB] Update WuAUC Compute

* [PSLIB] Change WuAUC Compute Method

* [PSLIB] Clean WuAUC Compute

* [PSLIB] Clean Metric Module Unused Code

* mv metric instance

* [PSLIB] Add Metrics Module, Support User-defined Add Metric (#38789)

* [PSLIB] Add Metrics Module, Support User-defined Add Metric

* [PSLIB] Modify According to CI

* [PSLIB] Modify According to CI

* [PSLIB] Modify According to CI

* [PSLIB] Modify According to CI Coverage

* [PSLIB] Modify According to CI

* [PSLIB] Modify According to CI

* [PSLIB] Modify According to CI

* [PSLIB] Modify According to CI

* [PSLIB] Modify According to CI

* [PSLIB] Modify According to CI Coverage

* [PSLIB] Modify According to CI Coverage

* [PSLIB] Modify According to CI Coverage

* modify role_maker

* update CMakeLists.txt
上级 d3011c75
......@@ -188,7 +188,7 @@ if(WITH_DISTRIBUTE)
dist_multi_trainer.cc trainer_factory.cc trainer.cc data_feed_factory.cc
data_feed.cc device_worker.cc hogwild_worker.cc downpour_worker.cc downpour_worker_opt.cc
pull_dense_worker.cc section_worker.cc device_worker_factory.cc data_set.cc DEPS op_registry
device_context scope framework_proto trainer_desc_proto glog fs shell fleet_wrapper box_wrapper lodtensor_printer
device_context scope framework_proto trainer_desc_proto glog fs shell fleet_wrapper box_wrapper metrics lodtensor_printer
lod_rank_table feed_fetch_method sendrecvop_rpc communicator collective_helper ${GLOB_DISTRIBUTE_DEPS}
graph_to_program_pass variable_helper data_feed_proto timer)
set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor")
......@@ -199,7 +199,7 @@ else()
data_feed.cc device_worker.cc hogwild_worker.cc downpour_worker.cc downpour_worker_opt.cc
pull_dense_worker.cc section_worker.cc device_worker_factory.cc data_set.cc DEPS op_registry
device_context scope framework_proto data_feed_proto trainer_desc_proto glog
lod_rank_table fs shell fleet_wrapper box_wrapper lodtensor_printer feed_fetch_method
lod_rank_table fs shell fleet_wrapper box_wrapper metrics lodtensor_printer feed_fetch_method
graph_to_program_pass variable_helper timer)
cc_test(test_naive_executor SRCS naive_executor_test.cc DEPS naive_executor elementwise_add_op)
endif()
......
......@@ -232,6 +232,7 @@ InMemoryDataFeed<T>::InMemoryDataFeed() {
this->thread_id_ = 0;
this->thread_num_ = 1;
this->parse_ins_id_ = false;
this->parse_uid_ = false;
this->parse_content_ = false;
this->parse_logkey_ = false;
this->enable_pv_merge_ = false;
......@@ -362,6 +363,11 @@ void InMemoryDataFeed<T>::SetParseInsId(bool parse_ins_id) {
parse_ins_id_ = parse_ins_id;
}
// Stores the uid-parsing switch in parse_uid_. The multi-slot parser shown in
// this change checks that flag before filling Record::uid_.
template <typename T>
void InMemoryDataFeed<T>::SetParseUid(bool parse_uid) {
parse_uid_ = parse_uid;
}
template <typename T>
void InMemoryDataFeed<T>::LoadIntoMemory() {
#ifdef _LINUX
......@@ -838,6 +844,7 @@ void MultiSlotInMemoryDataFeed::Init(
use_slots_shape_.push_back(local_shape);
}
}
uid_slot_ = multi_slot_desc.uid_slot();
feed_vec_.resize(use_slots_.size());
pipe_command_ = data_feed_desc.pipe_command();
finish_init_ = true;
......@@ -929,6 +936,17 @@ bool MultiSlotInMemoryDataFeed::ParseOneInstanceFromPipe(Record* instance) {
"\nWe detect the feasign number of this slot is %d, "
"which is illegal.",
str, i, num));
if (parse_uid_ && all_slots_[i] == uid_slot_) {
PADDLE_ENFORCE(num == 1 && all_slots_type_[i][0] == 'u',
"The uid has to be uint64 and single.\n"
"please check this error line: %s",
str);
char* uidptr = endptr;
uint64_t feasign = (uint64_t)strtoull(uidptr, &uidptr, 10);
instance->uid_ = feasign;
}
if (idx != -1) {
if (all_slots_type_[i][0] == 'f') { // float
for (int j = 0; j < num; ++j) {
......
......@@ -92,6 +92,7 @@ struct Record {
uint64_t search_id;
uint32_t rank;
uint32_t cmatch;
std::string uid_;
};
struct PvInstanceObject {
......@@ -157,6 +158,7 @@ class DataFeed {
virtual void SetThreadNum(int thread_num) {}
// This function will do nothing at default
virtual void SetParseInsId(bool parse_ins_id) {}
virtual void SetParseUid(bool parse_uid) {}
virtual void SetParseContent(bool parse_content) {}
virtual void SetParseLogKey(bool parse_logkey) {}
virtual void SetEnablePvMerge(bool enable_pv_merge) {}
......@@ -232,6 +234,7 @@ class DataFeed {
std::vector<std::string> ins_id_vec_;
std::vector<std::string> ins_content_vec_;
platform::Place place_;
std::string uid_slot_;
};
// PrivateQueueDataFeed is the base virtual class for other DataFeeds.
......@@ -293,6 +296,7 @@ class InMemoryDataFeed : public DataFeed {
virtual void SetThreadId(int thread_id);
virtual void SetThreadNum(int thread_num);
virtual void SetParseInsId(bool parse_ins_id);
virtual void SetParseUid(bool parse_uid);
virtual void SetParseContent(bool parse_content);
virtual void SetParseLogKey(bool parse_logkey);
virtual void SetEnablePvMerge(bool enable_pv_merge);
......@@ -307,6 +311,7 @@ class InMemoryDataFeed : public DataFeed {
int thread_id_;
int thread_num_;
bool parse_ins_id_;
bool parse_uid_;
bool parse_content_;
bool parse_logkey_;
bool enable_pv_merge_;
......@@ -471,7 +476,7 @@ paddle::framework::Archive<AR>& operator>>(paddle::framework::Archive<AR>& ar,
for (size_t& x : offset) {
uint64_t t;
ar >> t;
x = (size_t)t;
x = static_cast<size_t>(t);
}
#endif
ar >> ins.MutableFloatData();
......
......@@ -22,7 +22,10 @@ message Slot {
repeated int32 shape = 5; // we can define N-D Tensor
}
message MultiSlotDesc { repeated Slot slots = 1; }
// Describes the slots consumed by a multi-slot data feed.
message MultiSlotDesc {
repeated Slot slots = 1;
// Name of the slot that carries the uid; the in-memory feed reads it through
// uid_slot() during Init and compares it against each slot name while parsing.
optional string uid_slot = 2;
}
message DataFeedDesc {
optional string name = 1;
......
......@@ -54,6 +54,8 @@ DatasetImpl<T>::DatasetImpl() {
parse_logkey_ = false;
preload_thread_num_ = 0;
global_index_ = 0;
shuffle_by_uid_ = false;
parse_uid_ = false;
}
// set filelist, file_idx_ will reset to zero.
......@@ -147,6 +149,12 @@ void DatasetImpl<T>::SetMergeBySid(bool is_merge) {
merge_by_sid_ = is_merge;
}
// Records whether global shuffle should pick the target trainer by hashing the
// record's uid (see the get_client_id lambda in GlobalShuffle).
// NOTE(review): parse_uid_ is switched on unconditionally, even when
// enable_shuffle_uid is false — confirm this is intended.
template <typename T>
void DatasetImpl<T>::SetShuffleByUid(bool enable_shuffle_uid) {
shuffle_by_uid_ = enable_shuffle_uid;
parse_uid_ = true;
}
template <typename T>
void DatasetImpl<T>::SetEnablePvMerge(bool enable_pv_merge) {
enable_pv_merge_ = enable_pv_merge;
......@@ -386,11 +394,14 @@ void DatasetImpl<T>::GlobalShuffle(int thread_num) {
<< input_channel_->Size();
auto get_client_id = [this, fleet_ptr](const T& data) -> size_t {
if (!this->merge_by_insid_) {
return fleet_ptr->LocalRandomEngine()() % this->trainer_num_;
} else {
if (this->merge_by_insid_) {
return XXH64(data.ins_id_.data(), data.ins_id_.length(), 0) %
this->trainer_num_;
} else if (this->shuffle_by_uid_) {
return XXH64(data.uid_.data(), data.uid_.length(), 0) %
this->trainer_num_;
} else {
return fleet_ptr->LocalRandomEngine()() % this->trainer_num_;
}
};
......@@ -618,6 +629,7 @@ void DatasetImpl<T>::CreateReaders() {
readers_[i]->SetFileListIndex(&file_idx_);
readers_[i]->SetFileList(filelist_);
readers_[i]->SetParseInsId(parse_ins_id_);
readers_[i]->SetParseUid(parse_uid_);
readers_[i]->SetParseContent(parse_content_);
readers_[i]->SetParseLogKey(parse_logkey_);
readers_[i]->SetEnablePvMerge(enable_pv_merge_);
......@@ -686,6 +698,7 @@ void DatasetImpl<T>::CreatePreLoadReaders() {
preload_readers_[i]->SetFileListIndex(&file_idx_);
preload_readers_[i]->SetFileList(filelist_);
preload_readers_[i]->SetParseInsId(parse_ins_id_);
preload_readers_[i]->SetParseUid(parse_uid_);
preload_readers_[i]->SetParseContent(parse_content_);
preload_readers_[i]->SetParseLogKey(parse_logkey_);
preload_readers_[i]->SetEnablePvMerge(enable_pv_merge_);
......
......@@ -68,6 +68,7 @@ class Dataset {
virtual void SetParseLogKey(bool parse_logkey) = 0;
virtual void SetEnablePvMerge(bool enable_pv_merge) = 0;
virtual void SetMergeBySid(bool is_merge) = 0;
virtual void SetShuffleByUid(bool enable_shuffle_uid) = 0;
// set merge by ins id
virtual void SetMergeByInsId(int merge_size) = 0;
virtual void SetGenerateUniqueFeasign(bool gen_uni_feasigns) = 0;
......@@ -175,6 +176,7 @@ class DatasetImpl : public Dataset {
virtual void SetParseLogKey(bool parse_logkey);
virtual void SetEnablePvMerge(bool enable_pv_merge);
virtual void SetMergeBySid(bool is_merge);
virtual void SetShuffleByUid(bool enable_shuffle_uid);
virtual void SetMergeByInsId(int merge_size);
virtual void SetGenerateUniqueFeasign(bool gen_uni_feasigns);
......@@ -263,6 +265,8 @@ class DatasetImpl : public Dataset {
bool parse_content_;
bool parse_logkey_;
bool merge_by_sid_;
bool shuffle_by_uid_;
bool parse_uid_;
bool enable_pv_merge_; // True means to merge pv
int current_phase_; // 1 join, 0 update
size_t merge_size_;
......
......@@ -15,6 +15,7 @@ limitations under the License. */
#include "paddle/fluid/framework/device_worker.h"
#include "paddle/fluid/framework/device_worker_factory.h"
#include "paddle/fluid/framework/fleet/fleet_wrapper.h"
#include "paddle/fluid/framework/fleet/metrics.h"
#include "paddle/fluid/platform/cpu_helper.h"
#include "paddle/fluid/string/string_helper.h"
......@@ -25,7 +26,6 @@ limitations under the License. */
namespace paddle {
namespace framework {
void DownpourWorker::Initialize(const TrainerDesc& desc) {
param_ = desc.downpour_param();
for (int i = 0; i < param_.sparse_table_size(); ++i) {
......@@ -780,6 +780,21 @@ void DownpourWorker::TrainFilesWithProfiler() {
}
}
/**
 * @brief Hand the current batch to every registered metric whose phase equals
 *        the phase reported by the global Metric instance.
 * @param scope scope forwarded to each metric's add_data
 * @param place place forwarded to each metric's add_data
 */
inline void AddAucMonitor(const Scope* scope, const platform::Place& place) {
  auto metric_ptr = Metric::GetInstance();
  auto& metric_list = metric_ptr->GetMetricList();
  for (auto& entry : metric_list) {
    auto* metric_msg = entry.second;
    // Only metrics registered for the active phase receive data.
    if (metric_ptr->Phase() == metric_msg->MetricPhase()) {
      metric_msg->add_data(scope, place);
    }
  }
}
void DownpourWorker::TrainFiles() {
VLOG(3) << "Begin to train files";
platform::SetNumThreads(1);
......@@ -877,6 +892,11 @@ void DownpourWorker::TrainFiles() {
}
}
// add data for MetricMsg
if (Metric::GetInstance() != nullptr) {
AddAucMonitor(thread_scope_, place_);
}
// check inf and nan
for (std::string& var_name : check_nan_var_names_) {
Variable* var = thread_scope_->FindVar(var_name);
......
......@@ -15,8 +15,10 @@ endif(WITH_BOX_PS)
# gloo_wrapper depends on the gloo target only when WITH_GLOO is enabled.
# The metrics library is declared identically in both branches and always
# depends on gloo_wrapper.
if(WITH_GLOO)
cc_library(gloo_wrapper SRCS gloo_wrapper.cc DEPS framework_proto variable_helper scope gloo)
cc_library(metrics SRCS metrics.cc DEPS gloo_wrapper)
else()
cc_library(gloo_wrapper SRCS gloo_wrapper.cc DEPS framework_proto variable_helper scope)
cc_library(metrics SRCS metrics.cc DEPS gloo_wrapper)
endif(WITH_GLOO)
cc_test(test_fleet SRCS test_fleet.cc DEPS fleet_wrapper gloo_wrapper fs shell)
......@@ -105,6 +105,11 @@ enum GlooStoreType { HDFS, HTTP };
class GlooWrapper {
public:
static std::shared_ptr<GlooWrapper> GetInstance() {
static auto s_instance = std::make_shared<GlooWrapper>();
return s_instance;
}
GlooWrapper() {}
virtual ~GlooWrapper() {}
......@@ -153,6 +158,8 @@ class GlooWrapper {
#endif
}
// Reports the current value of is_initialized_.
bool IsInitialized() { return is_initialized_; }
template <typename T>
std::vector<T> AllReduce(std::vector<T>& sendbuf, // NOLINT
const std::string& mode = "sum") { // NOLINT
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/fleet/metrics.h"
#include <algorithm>
#include <ctime>
#include <memory>
#include <numeric>
#include "paddle/fluid/framework/lod_tensor.h"
namespace paddle {
namespace framework {
// Out-of-class definition of Metric's static instance pointer; starts as null.
std::shared_ptr<Metric> Metric::s_instance_ = nullptr;
// Sets the number of histogram buckets and rebuilds both histograms
// (index 0 = label 0, index 1 = label 1) zero-filled.
void BasicAucCalculator::init(int table_size) {
  set_table_size(table_size);
  // Replace both CPU tables with fresh, empty vectors.
  for (auto& histogram : _table) {
    histogram = std::vector<double>();
  }
  // Size the tables to _table_size and clear the running error sums.
  reset();
}
// Zero-fills both histograms to _table_size entries and clears the locally
// accumulated absolute error, squared error and prediction sum.
void BasicAucCalculator::reset() {
  for (auto& histogram : _table) {
    histogram.assign(_table_size, 0.0);
  }
  _local_abserr = 0;
  _local_sqrerr = 0;
  _local_pred = 0;
}
// Adds a batch of (prediction, label) pairs to the histograms.
//   d_pred / d_label: arrays holding batch_size elements each
//   place: not used by this implementation
// The inputs are first copied into per-thread buffers, then folded in while
// holding _table_mutex.
void BasicAucCalculator::add_data(const float* d_pred, const int64_t* d_label,
                                  int batch_size,
                                  const paddle::platform::Place& place) {
  thread_local std::vector<float> h_pred;
  thread_local std::vector<int64_t> h_label;
  h_pred.resize(batch_size);
  h_label.resize(batch_size);
  std::copy(d_pred, d_pred + batch_size, h_pred.begin());
  std::copy(d_label, d_label + batch_size, h_label.begin());

  std::lock_guard<std::mutex> lock(_table_mutex);
  for (int idx = 0; idx != batch_size; ++idx) {
    add_unlock_data(h_pred[idx], h_label[idx]);
  }
}
// Adds one sample to the histograms without taking _table_mutex; callers in
// this file hold the mutex before calling.
//   pred:  prediction, enforced to lie in [0, 1]
//   label: enforced to be 0 or 1 (label * label == label)
void BasicAucCalculator::add_unlock_data(double pred, int label) {
  PADDLE_ENFORCE_GE(pred, 0.0, platform::errors::PreconditionNotMet(
                                   "pred should be greater than 0"));
  PADDLE_ENFORCE_LE(pred, 1.0, platform::errors::PreconditionNotMet(
                                   "pred should be lower than 1"));
  PADDLE_ENFORCE_EQ(
      label * label, label,
      platform::errors::PreconditionNotMet(
          "label must be equal to 0 or 1, but its value is: %d", label));

  // Map the prediction to a bucket; pred == 1.0 falls into the last bucket.
  const int scaled = static_cast<int>(pred * _table_size);
  const int bucket = std::min(scaled, _table_size - 1);
  PADDLE_ENFORCE_GE(
      bucket, 0,
      platform::errors::PreconditionNotMet(
          "pos must be equal or greater than 0, but its value is: %d", bucket));
  PADDLE_ENFORCE_LT(
      bucket, _table_size,
      platform::errors::PreconditionNotMet(
          "pos must be less than table_size, but its value is: %d", bucket));

  const double diff = pred - label;
  _local_abserr += fabs(diff);
  _local_sqrerr += diff * diff;
  _local_pred += pred;
  ++_table[label][bucket];
}
// Adds a batch of (prediction, label) pairs, skipping every sample whose mask
// entry is zero.
//   d_pred / d_label / d_mask: arrays holding batch_size elements each
//   place: not used by this implementation
void BasicAucCalculator::add_mask_data(const float* d_pred,
                                       const int64_t* d_label,
                                       const int64_t* d_mask, int batch_size,
                                       const paddle::platform::Place& place) {
  thread_local std::vector<float> h_pred;
  thread_local std::vector<int64_t> h_label;
  thread_local std::vector<int64_t> h_mask;
  h_pred.resize(batch_size);
  h_label.resize(batch_size);
  h_mask.resize(batch_size);
  std::copy(d_pred, d_pred + batch_size, h_pred.begin());
  std::copy(d_label, d_label + batch_size, h_label.begin());
  std::copy(d_mask, d_mask + batch_size, h_mask.begin());

  std::lock_guard<std::mutex> lock(_table_mutex);
  for (int idx = 0; idx != batch_size; ++idx) {
    if (h_mask[idx] == 0) {
      continue;
    }
    add_unlock_data(h_pred[idx], h_label[idx]);
  }
}
// Derives _auc, _mae, _rmse, _predicted_ctr, _actual_ctr, _size and (through
// calculate_bucket_error) _bucket_error from the accumulated histograms and
// error sums. The whole body is compiled only when PADDLE_WITH_GLOO is
// defined; without it the function does nothing.
void BasicAucCalculator::compute() {
#if defined(PADDLE_WITH_GLOO)
double area = 0;
double fp = 0;
double tp = 0;
auto gloo_wrapper = paddle::framework::GlooWrapper::GetInstance();
// Initialise the wrapper on demand if nobody has done so yet.
if (!gloo_wrapper->IsInitialized()) {
VLOG(0) << "GLOO is not inited";
gloo_wrapper->Init();
}
if (gloo_wrapper->Size() > 1) {
// More than one participant: sum both histograms across participants first.
auto neg_table = gloo_wrapper->AllReduce(_table[0], "sum");
auto pos_table = gloo_wrapper->AllReduce(_table[1], "sum");
// Walk buckets from the highest prediction to the lowest, adding one
// trapezoid of the (fp, tp) curve per bucket.
for (int i = _table_size - 1; i >= 0; i--) {
double newfp = fp + neg_table[i];
double newtp = tp + pos_table[i];
area += (newfp - fp) * (tp + newtp) / 2;
fp = newfp;
tp = newtp;
}
} else {
// Single participant: same sweep over the local histograms.
for (int i = _table_size - 1; i >= 0; i--) {
double newfp = fp + _table[0][i];
double newtp = tp + _table[1][i];
area += (newfp - fp) * (tp + newtp) / 2;
fp = newfp;
tp = newtp;
}
}
if (fp < 1e-3 || tp < 1e-3) {
_auc = -0.5;  // which means all nonclick or click
} else {
_auc = area / (fp * tp);
}
// NOTE(review): the divisions below use fp + tp as the denominator with no
// guard; with no samples at all this divides by zero — confirm callers never
// reach here with empty data.
if (gloo_wrapper->Size() > 1) {
// allreduce sum
std::vector<double> local_abserr_vec(1, _local_abserr);
std::vector<double> local_sqrerr_vec(1, _local_sqrerr);
std::vector<double> local_pred_vec(1, _local_pred);
auto global_abserr_vec = gloo_wrapper->AllReduce(local_abserr_vec, "sum");
auto global_sqrerr_vec = gloo_wrapper->AllReduce(local_sqrerr_vec, "sum");
auto global_pred_vec = gloo_wrapper->AllReduce(local_pred_vec, "sum");
_mae = global_abserr_vec[0] / (fp + tp);
_rmse = sqrt(global_sqrerr_vec[0] / (fp + tp));
_predicted_ctr = global_pred_vec[0] / (fp + tp);
} else {
_mae = _local_abserr / (fp + tp);
_rmse = sqrt(_local_sqrerr / (fp + tp));
_predicted_ctr = _local_pred / (fp + tp);
}
_actual_ctr = tp / (fp + tp);
_size = fp + tp;
calculate_bucket_error();
#endif
}
// Computes _bucket_error from the label-0 / label-1 histograms.
//
// Buckets are scanned in increasing order of predicted CTR. Consecutive
// buckets are pooled until the pooled estimate's relative error drops below
// kRelativeErrorBound; a pool is restarted whenever it would span more than
// kMaxSpan of predicted CTR. Each accepted pool contributes its
// |actual / predicted - 1| weighted by its impression count.
//
// With more than one Gloo participant the histograms are summed across
// participants first (negative table, then positive table). The body is
// compiled only when PADDLE_WITH_GLOO is defined.
void BasicAucCalculator::calculate_bucket_error() {
#if defined(PADDLE_WITH_GLOO)
  // Shared scan used for both the all-reduced and the local histograms.
  auto scan_tables = [this](const std::vector<double>& neg_table,
                            const std::vector<double>& pos_table) -> double {
    double last_ctr = -1;
    double impression_sum = 0;
    double ctr_sum = 0.0;
    double click_sum = 0.0;
    double error_sum = 0.0;
    double error_count = 0;
    for (int i = 0; i < _table_size; i++) {
      double click = pos_table[i];
      double show = neg_table[i] + pos_table[i];
      double ctr = static_cast<double>(i) / _table_size;
      // Start a new pool when the current one has drifted too far.
      if (fabs(ctr - last_ctr) > kMaxSpan) {
        last_ctr = ctr;
        impression_sum = 0.0;
        ctr_sum = 0.0;
        click_sum = 0.0;
      }
      impression_sum += show;
      ctr_sum += ctr * show;
      click_sum += click;
      // With an empty pool these divisions yield NaN/inf, which fails the
      // comparison below, so the pool is simply not accepted yet.
      double adjust_ctr = ctr_sum / impression_sum;
      double relative_error =
          sqrt((1 - adjust_ctr) / (adjust_ctr * impression_sum));
      if (relative_error < kRelativeErrorBound) {
        double actual_ctr = click_sum / impression_sum;
        double relative_ctr_error = fabs(actual_ctr / adjust_ctr - 1);
        error_sum += relative_ctr_error * impression_sum;
        error_count += impression_sum;
        // Force the next bucket to open a fresh pool.
        last_ctr = -1;
      }
    }
    return error_count > 0 ? error_sum / error_count : 0.0;
  };

  auto gloo_wrapper = paddle::framework::GlooWrapper::GetInstance();
  if (gloo_wrapper->Size() > 1) {
    auto neg_table = gloo_wrapper->AllReduce(_table[0], "sum");
    auto pos_table = gloo_wrapper->AllReduce(_table[1], "sum");
    _bucket_error = scan_tables(neg_table, pos_table);
  } else {
    // Pass the vectors themselves instead of &_table[k][0], which is not
    // valid on an empty vector.
    _bucket_error = scan_tables(_table[0], _table[1]);
  }
#endif
}
// Drops all buffered per-user records and zeroes the user-AUC accumulators
// (_uauc, _wuauc, _user_cnt) together with _size.
void BasicAucCalculator::reset_records() {
  wuauc_records_.clear();
  _uauc = 0;
  _wuauc = 0;
  _user_cnt = 0;
  _size = 0;
}
// Buffers a batch of (prediction, label, uid) triples for the per-user AUC.
//   d_pred / d_label / d_uid: arrays holding batch_size elements each; uids
//     arrive as int64_t and are stored as uint64_t
//   place: not used by this implementation
void BasicAucCalculator::add_uid_data(const float* d_pred,
                                      const int64_t* d_label,
                                      const int64_t* d_uid, int batch_size,
                                      const paddle::platform::Place& place) {
  thread_local std::vector<float> h_pred;
  thread_local std::vector<int64_t> h_label;
  thread_local std::vector<uint64_t> h_uid;
  h_pred.resize(batch_size);
  h_label.resize(batch_size);
  h_uid.resize(batch_size);
  std::copy(d_pred, d_pred + batch_size, h_pred.begin());
  std::copy(d_label, d_label + batch_size, h_label.begin());
  std::copy(d_uid, d_uid + batch_size, h_uid.begin());

  std::lock_guard<std::mutex> lock(_table_mutex);
  for (int idx = 0; idx != batch_size; ++idx) {
    add_uid_unlock_data(h_pred[idx], h_label[idx], h_uid[idx]);
  }
}
void BasicAucCalculator::add_uid_unlock_data(double pred, int label,
uint64_t uid) {
PADDLE_ENFORCE_GE(pred, 0.0, platform::errors::PreconditionNotMet(
"pred should be greater than 0"));
PADDLE_ENFORCE_LE(pred, 1.0, platform::errors::PreconditionNotMet(
"pred should be lower than 1"));
PADDLE_ENFORCE_EQ(
label * label, label,
platform::errors::PreconditionNotMet(
"label must be equal to 0 or 1, but its value is: %d", label));
WuaucRecord record;
record.uid_ = uid;
record.label_ = label;
record.pred_ = pred;
wuauc_records_.emplace_back(std::move(record));
}
void BasicAucCalculator::computeWuAuc() {
std::sort(wuauc_records_.begin(), wuauc_records_.end(),
[](const WuaucRecord& lhs, const WuaucRecord& rhs) {
if (lhs.uid_ == rhs.uid_) {
if (lhs.pred_ == rhs.pred_) {
return lhs.label_ < rhs.label_;
} else {
return lhs.pred_ > rhs.pred_;
}
} else {
return lhs.uid_ > rhs.uid_;
}
});
WuaucRocData roc_data;
uint64_t prev_uid = 0;
size_t prev_pos = 0;
for (size_t i = 0; i < wuauc_records_.size(); ++i) {
if (wuauc_records_[i].uid_ != prev_uid) {
std::vector<WuaucRecord> single_user_recs(
wuauc_records_.begin() + prev_pos, wuauc_records_.begin() + i);
roc_data = computeSingelUserAuc(single_user_recs);
if (roc_data.auc_ != -1) {
double ins_num = (roc_data.tp_ + roc_data.fp_);
_user_cnt += 1;
_size += ins_num;
_uauc += roc_data.auc_;
_wuauc += roc_data.auc_ * ins_num;
}
prev_uid = wuauc_records_[i].uid_;
prev_pos = i;
}
}
std::vector<WuaucRecord> single_user_recs(wuauc_records_.begin() + prev_pos,
wuauc_records_.end());
roc_data = computeSingelUserAuc(single_user_recs);
if (roc_data.auc_ != -1) {
double ins_num = (roc_data.tp_ + roc_data.fp_);
_user_cnt += 1;
_size += ins_num;
_uauc += roc_data.auc_;
_wuauc += roc_data.auc_ * ins_num;
}
}
// Computes the ROC area for one user's records, taken in the given order.
// Runs of records with equal pred_ are treated as a single step so ties
// contribute a trapezoid.
// Returns {tp, fp, auc}: the label-1 count, the remaining count, and
// area / (fp * tp + 1e-9); auc is -1 when either count is zero.
BasicAucCalculator::WuaucRocData BasicAucCalculator::computeSingelUserAuc(
    const std::vector<WuaucRecord>& records) {
  double tp = 0.0;
  double fp = 0.0;
  double area = 0.0;
  const size_t total = records.size();

  size_t head = 0;
  while (head < total) {
    // Find the end of the run that shares records[head]'s prediction.
    size_t tail = head + 1;
    while (tail < total && records[tail].pred_ == records[head].pred_) {
      ++tail;
    }
    // Count positives and negatives inside the run.
    double newtp = tp;
    double newfp = fp;
    for (size_t k = head; k < tail; ++k) {
      if (records[k].label_ == 1) {
        newtp += 1;
      } else {
        newfp += 1;
      }
    }
    area += (newfp - fp) * (tp + newtp) / 2.0;
    tp = newtp;
    fp = newfp;
    head = tail;
  }

  const double auc = (tp > 0 && fp > 0) ? area / (fp * tp + 1e-9) : -1;
  return {tp, fp, auc};
}
} // namespace framework
} // namespace paddle
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <ThreadPool.h>
#include <atomic>
#include <ctime>
#include <map>
#include <memory>
#include <random>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/framework/variable_helper.h"
#include "paddle/fluid/platform/timer.h"
#include "paddle/fluid/string/string_helper.h"
#if defined(PADDLE_WITH_GLOO)
#include <gloo/allreduce.h>
#include "paddle/fluid/framework/fleet/gloo_wrapper.h"
#endif
namespace paddle {
namespace framework {
// Accumulates prediction/label statistics and derives AUC-style metrics.
//
// Two kinds of state are kept:
//  * two histograms (_table[0] for label 0, _table[1] for label 1) bucketed by
//    prediction, used by compute();
//  * a flat list of (uid, label, pred) records, used by computeWuAuc().
// All scalar members start at zero so the getters are well defined before
// init()/compute() have run.
class BasicAucCalculator {
 public:
  BasicAucCalculator() {}
  // One buffered sample for the per-user AUC computation.
  struct WuaucRecord {
    uint64_t uid_;
    int label_;
    float pred_;
  };
  // Result of scoring one user: label-1 count, label-0 count, and AUC
  // (-1 when undefined).
  struct WuaucRocData {
    double tp_;
    double fp_;
    double auc_;
  };
  // Sets the bucket count and zero-fills the histograms.
  void init(int table_size);
  // NOTE(review): declared here, but no definition is visible in this view.
  void init_wuauc(int table_size);
  // Zero-fills the histograms and the local error sums.
  void reset();
  // Clears the buffered per-user records and their accumulators.
  void reset_records();
  // Add a single sample WITHOUT locking; the caller must hold table_mutex().
  void add_unlock_data(double pred, int label);
  void add_uid_unlock_data(double pred, int label, uint64_t uid);
  // add batch data
  void add_data(const float* d_pred, const int64_t* d_label, int batch_size,
                const paddle::platform::Place& place);
  // add mask data
  void add_mask_data(const float* d_pred, const int64_t* d_label,
                     const int64_t* d_mask, int batch_size,
                     const paddle::platform::Place& place);
  // add uid data
  void add_uid_data(const float* d_pred, const int64_t* d_label,
                    const int64_t* d_uid, int batch_size,
                    const paddle::platform::Place& place);
  // Derives the histogram-based metrics (auc, mae, rmse, ctr, bucket error).
  void compute();
  // Accumulates per-user AUC sums from the buffered records.
  void computeWuAuc();
  WuaucRocData computeSingelUserAuc(const std::vector<WuaucRecord>& records);
  int table_size() const { return _table_size; }
  double bucket_error() const { return _bucket_error; }
  double auc() const { return _auc; }
  double uauc() const { return _uauc; }
  double wuauc() const { return _wuauc; }
  double mae() const { return _mae; }
  double actual_ctr() const { return _actual_ctr; }
  double predicted_ctr() const { return _predicted_ctr; }
  double user_cnt() const { return _user_cnt; }
  double size() const { return _size; }
  double rmse() const { return _rmse; }
  std::unordered_set<uint64_t> uid_keys() const { return _uid_keys; }
  // Mutex guarding the histograms and record list; exposed so callers can
  // batch several add_unlock_data calls under one lock.
  std::mutex& table_mutex(void) { return _table_mutex; }

 private:
  void calculate_bucket_error();

 protected:
  double _local_abserr = 0;
  double _local_sqrerr = 0;
  double _local_pred = 0;
  double _auc = 0;
  double _uauc = 0;
  double _wuauc = 0;
  double _mae = 0;
  double _rmse = 0;
  double _actual_ctr = 0;
  double _predicted_ctr = 0;
  // Previously left uninitialised; size() could return an indeterminate value.
  double _size = 0;
  double _user_cnt = 0;
  double _bucket_error = 0;
  std::unordered_set<uint64_t> _uid_keys;

 private:
  void set_table_size(int table_size) { _table_size = table_size; }
  // Previously left uninitialised; reset() sizes the tables from it, so a
  // calculator that never had init() called must see 0 here.
  int _table_size = 0;
  std::vector<double> _table[2];
  std::vector<WuaucRecord> wuauc_records_;
  static constexpr double kRelativeErrorBound = 0.05;
  static constexpr double kMaxSpan = 0.01;
  std::mutex _table_mutex;
};
class Metric {
public:
virtual ~Metric() {}
Metric() { fprintf(stdout, "init fleet Metric\n"); }
// Base metric: reads a label tensor and a prediction tensor from a scope and
// feeds them to a BasicAucCalculator. Subclasses override add_data to apply
// their own sample selection.
class MetricMsg {
 public:
  // Leaves the names empty, the phase at 0 and the calculator null; derived
  // constructors fill these in.
  MetricMsg() {}
  // label_varname / pred_varname: variable names looked up in the scope
  // metric_phase: phase this metric is active in (compared by the caller)
  // bucket_size:  number of histogram buckets for the calculator
  MetricMsg(const std::string& label_varname, const std::string& pred_varname,
            int metric_phase, int bucket_size = 1000000)
      : label_varname_(label_varname),
        pred_varname_(pred_varname),
        metric_phase_(metric_phase) {
    calculator = new BasicAucCalculator();
    calculator->init(bucket_size);
  }
  // NOTE(review): calculator is allocated with new and is not released here —
  // confirm who owns it.
  virtual ~MetricMsg() {}
  int MetricPhase() const { return metric_phase_; }
  BasicAucCalculator* GetCalculator() { return calculator; }
  // Reads label and prediction tensors, requires equal lengths, and adds the
  // whole batch to the calculator.
  virtual void add_data(const Scope* exe_scope,
                        const paddle::platform::Place& place) {
    int label_len = 0;
    const int64_t* label_data = nullptr;
    int pred_len = 0;
    const float* pred_data = nullptr;
    get_data<int64_t>(exe_scope, label_varname_, &label_data, &label_len);
    get_data<float>(exe_scope, pred_varname_, &pred_data, &pred_len);
    PADDLE_ENFORCE_EQ(label_len, pred_len,
                      platform::errors::PreconditionNotMet(
                          "the predict data length should be consistent with "
                          "the label data length"));
    calculator->add_data(pred_data, label_data, label_len, place);
  }
  // Looks up varname in the scope and returns a pointer to the tensor's data
  // plus its element count; no copy is made. Fails if the variable is missing.
  template <class T = float>
  static void get_data(const Scope* exe_scope, const std::string& varname,
                       const T** data, int* len) {
    auto* var = exe_scope->FindVar(varname.c_str());
    PADDLE_ENFORCE_NOT_NULL(
        var, platform::errors::NotFound(
                 "Error: var %s is not found in scope.", varname.c_str()));
    auto& cpu_tensor = var->Get<LoDTensor>();
    *data = cpu_tensor.data<T>();
    *len = cpu_tensor.numel();
  }
  // Same lookup, but copies the tensor's elements into *data.
  template <class T = float>
  static void get_data(const Scope* exe_scope, const std::string& varname,
                       std::vector<T>* data) {
    auto* var = exe_scope->FindVar(varname.c_str());
    PADDLE_ENFORCE_NOT_NULL(
        var, platform::errors::NotFound(
                 "Error: var %s is not found in scope.", varname.c_str()));
    auto& cpu_tensor = var->Get<LoDTensor>();
    auto* cpu_data = cpu_tensor.data<T>();
    auto len = cpu_tensor.numel();
    data->resize(len);
    memcpy(data->data(), cpu_data, sizeof(T) * len);
  }
  // Converts a raw cmatch_rank value into a (cmatch, rank) pair. The rank is
  // always reported as 0, so only specs with rank 0 can compare equal.
  static inline std::pair<int, int> parse_cmatch_rank(uint64_t x) {
    // only consider ignore_rank=True
    return std::make_pair(static_cast<int>(x), 0);
    // first 32 bit store cmatch and second 32 bit store rank
    // return std::make_pair(static_cast<int>(x >> 32),
    //                       static_cast<int>(x & 0xff));
  }

 protected:
  std::string label_varname_;
  std::string pred_varname_;
  // Default-initialised so an object built with MetricMsg() never exposes
  // indeterminate values through MetricPhase() / GetCalculator().
  int metric_phase_ = 0;
  BasicAucCalculator* calculator = nullptr;
};
// Metric that buffers (uid, label, prediction) triples for per-user AUC.
class WuAucMetricMsg : public MetricMsg {
 public:
  // label_varname / pred_varname / uid_varname: variables read from the scope
  // metric_phase: phase this metric is active in
  // bucket_size:  accepted but not used here; the calculator's histogram is
  //               not initialised by this constructor
  WuAucMetricMsg(const std::string& label_varname,
                 const std::string& pred_varname,
                 const std::string& uid_varname, int metric_phase,
                 int bucket_size = 1000000) {
    label_varname_ = label_varname;
    pred_varname_ = pred_varname;
    uid_varname_ = uid_varname;
    metric_phase_ = metric_phase;
    calculator = new BasicAucCalculator();
  }
  virtual ~WuAucMetricMsg() {}
  // Reads label, prediction and uid tensors and forwards them as one batch.
  // All three must have the same length: the calculator copies label_len
  // elements from each array.
  void add_data(const Scope* exe_scope,
                const paddle::platform::Place& place) override {
    int label_len = 0;
    const int64_t* label_data = nullptr;
    get_data<int64_t>(exe_scope, label_varname_, &label_data, &label_len);
    int pred_len = 0;
    const float* pred_data = nullptr;
    get_data<float>(exe_scope, pred_varname_, &pred_data, &pred_len);
    int uid_len = 0;
    const int64_t* uid_data = nullptr;
    get_data<int64_t>(exe_scope, uid_varname_, &uid_data, &uid_len);
    // Previously unchecked: a shorter prediction tensor was read past its end.
    PADDLE_ENFORCE_EQ(label_len, pred_len,
                      platform::errors::PreconditionNotMet(
                          "the predict data length should be consistent with "
                          "the label data length"));
    PADDLE_ENFORCE_EQ(label_len, uid_len,
                      platform::errors::PreconditionNotMet(
                          "the uid data length should be consistent with "
                          "the label data length"));
    auto cal = GetCalculator();
    cal->add_uid_data(pred_data, label_data, uid_data, label_len, place);
  }

 protected:
  std::string uid_varname_;
};
// Metric with one prediction variable per (cmatch, rank) spec: each sample is
// scored with the prediction that belongs to the first spec equal to the
// sample's parsed cmatch_rank; samples matching no spec are skipped.
class MultiTaskMetricMsg : public MetricMsg {
 public:
  // pred_varname_list: prediction variable names, split with split_string
  // cmatch_rank_group: "cmatch_rank" specs, split with split_string; each
  //                    spec must split into exactly two parts on '_'
  // The two lists must have the same number of entries.
  MultiTaskMetricMsg(const std::string& label_varname,
                     const std::string& pred_varname_list, int metric_phase,
                     const std::string& cmatch_rank_group,
                     const std::string& cmatch_rank_varname,
                     int bucket_size = 1000000) {
    label_varname_ = label_varname;
    cmatch_rank_varname_ = cmatch_rank_varname;
    metric_phase_ = metric_phase;
    calculator = new BasicAucCalculator();
    calculator->init(bucket_size);

    for (auto& cmatch_rank : string::split_string(cmatch_rank_group)) {
      const std::vector<std::string>& parts =
          string::split_string(cmatch_rank, "_");
      PADDLE_ENFORCE_EQ(
          parts.size(), 2,
          platform::errors::PreconditionNotMet(
              "illegal multitask auc spec: %s", cmatch_rank.c_str()));
      const int cmatch = atoi(parts[0].c_str());
      const int rank = atoi(parts[1].c_str());
      cmatch_rank_v.emplace_back(cmatch, rank);
    }

    for (const auto& name : string::split_string(pred_varname_list)) {
      pred_v.emplace_back(name);
    }

    PADDLE_ENFORCE_EQ(cmatch_rank_v.size(), pred_v.size(),
                      platform::errors::PreconditionNotMet(
                          "cmatch_rank's size [%lu] should be equal to pred "
                          "list's size [%lu], but ther are not equal",
                          cmatch_rank_v.size(), pred_v.size()));
  }
  virtual ~MultiTaskMetricMsg() {}
  // Copies cmatch_rank, label and every prediction tensor out of the scope,
  // requires them all to have the same length, then adds the matching samples
  // under the calculator's mutex.
  void add_data(const Scope* exe_scope,
                const paddle::platform::Place& place) override {
    std::vector<int64_t> cmatch_rank_data;
    get_data<int64_t>(exe_scope, cmatch_rank_varname_, &cmatch_rank_data);
    std::vector<int64_t> label_data;
    get_data<int64_t>(exe_scope, label_varname_, &label_data);
    size_t batch_size = cmatch_rank_data.size();
    PADDLE_ENFORCE_EQ(
        batch_size, label_data.size(),
        platform::errors::PreconditionNotMet(
            "illegal batch size: batch_size[%lu] and label_data[%lu]",
            batch_size, label_data.size()));

    const size_t task_count = pred_v.size();
    std::vector<std::vector<float>> pred_data_list(task_count);
    for (size_t task = 0; task < task_count; ++task) {
      get_data<float>(exe_scope, pred_v[task], &pred_data_list[task]);
    }
    for (size_t task = 0; task < pred_data_list.size(); ++task) {
      PADDLE_ENFORCE_EQ(
          batch_size, pred_data_list[task].size(),
          platform::errors::PreconditionNotMet(
              "illegal batch size: batch_size[%lu] and pred_data[%lu]",
              batch_size, pred_data_list[task].size()));
    }

    auto cal = GetCalculator();
    std::lock_guard<std::mutex> lock(cal->table_mutex());
    for (size_t i = 0; i < batch_size; ++i) {
      const auto wanted = parse_cmatch_rank(cmatch_rank_data[i]);
      const auto hit =
          std::find(cmatch_rank_v.begin(), cmatch_rank_v.end(), wanted);
      if (hit == cmatch_rank_v.end()) {
        continue;  // sample belongs to none of the configured tasks
      }
      const auto task = std::distance(cmatch_rank_v.begin(), hit);
      cal->add_unlock_data(pred_data_list[task][i], label_data[i]);
    }
  }

 protected:
  std::vector<std::pair<int, int>> cmatch_rank_v;
  std::vector<std::string> pred_v;
  std::string cmatch_rank_varname_;
};
// Metric restricted to samples whose parsed cmatch_rank matches one of the
// configured specs. With ignore_rank set, only the cmatch part is compared.
class CmatchRankMetricMsg : public MetricMsg {
 public:
  // cmatch_rank_group: specs split with split_string. With ignore_rank each
  //   spec is a bare cmatch (rank stored as 0); otherwise each spec must
  //   split into exactly two parts on '_'.
  CmatchRankMetricMsg(const std::string& label_varname,
                      const std::string& pred_varname, int metric_phase,
                      const std::string& cmatch_rank_group,
                      const std::string& cmatch_rank_varname,
                      bool ignore_rank = false, int bucket_size = 1000000) {
    label_varname_ = label_varname;
    pred_varname_ = pred_varname;
    cmatch_rank_varname_ = cmatch_rank_varname;
    metric_phase_ = metric_phase;
    ignore_rank_ = ignore_rank;
    calculator = new BasicAucCalculator();
    calculator->init(bucket_size);

    for (auto& cmatch_rank : string::split_string(cmatch_rank_group)) {
      if (ignore_rank) {  // CmatchAUC
        cmatch_rank_v.emplace_back(atoi(cmatch_rank.c_str()), 0);
      } else {
        const std::vector<std::string>& parts =
            string::split_string(cmatch_rank, "_");
        PADDLE_ENFORCE_EQ(
            parts.size(), 2,
            platform::errors::PreconditionNotMet(
                "illegal cmatch_rank auc spec: %s", cmatch_rank.c_str()));
        const int cmatch = atoi(parts[0].c_str());
        const int rank = atoi(parts[1].c_str());
        cmatch_rank_v.emplace_back(cmatch, rank);
      }
    }
  }
  virtual ~CmatchRankMetricMsg() {}
  // Copies cmatch_rank, label and prediction tensors out of the scope,
  // requires equal lengths, then adds each matching sample once under the
  // calculator's mutex.
  void add_data(const Scope* exe_scope,
                const paddle::platform::Place& place) override {
    std::vector<int64_t> cmatch_rank_data;
    get_data<int64_t>(exe_scope, cmatch_rank_varname_, &cmatch_rank_data);
    std::vector<int64_t> label_data;
    get_data<int64_t>(exe_scope, label_varname_, &label_data);
    std::vector<float> pred_data;
    get_data<float>(exe_scope, pred_varname_, &pred_data);

    size_t batch_size = cmatch_rank_data.size();
    PADDLE_ENFORCE_EQ(
        batch_size, label_data.size(),
        platform::errors::PreconditionNotMet(
            "illegal batch size: cmatch_rank[%lu] and label_data[%lu]",
            batch_size, label_data.size()));
    PADDLE_ENFORCE_EQ(
        batch_size, pred_data.size(),
        platform::errors::PreconditionNotMet(
            "illegal batch size: cmatch_rank[%lu] and pred_data[%lu]",
            batch_size, pred_data.size()));

    auto cal = GetCalculator();
    std::lock_guard<std::mutex> lock(cal->table_mutex());
    for (size_t i = 0; i < batch_size; ++i) {
      const auto& current = parse_cmatch_rank(cmatch_rank_data[i]);
      // Stop at the first spec that matches so a sample is added at most once.
      bool matched = false;
      for (const auto& spec : cmatch_rank_v) {
        matched = ignore_rank_ ? (spec.first == current.first)
                               : (spec == current);
        if (matched) {
          break;
        }
      }
      if (matched) {
        cal->add_unlock_data(pred_data[i], label_data[i]);
      }
    }
  }

 protected:
  std::vector<std::pair<int, int>> cmatch_rank_v;
  std::string cmatch_rank_varname_;
  bool ignore_rank_;
};
// AUC metric whose samples are passed to the calculator together with a
// per-sample mask variable.
class MaskMetricMsg : public MetricMsg {
 public:
  // Stores the variable names and phase, then creates a BasicAucCalculator
  // initialised with bucket_size.
  MaskMetricMsg(const std::string& label_varname,
                const std::string& pred_varname, int metric_phase,
                const std::string& mask_varname, int bucket_size = 1000000) {
    label_varname_ = label_varname;
    pred_varname_ = pred_varname;
    mask_varname_ = mask_varname;
    metric_phase_ = metric_phase;
    calculator = new BasicAucCalculator();
    calculator->init(bucket_size);
  }

  virtual ~MaskMetricMsg() {}

  // Fetches label, prediction and mask buffers from exe_scope and forwards
  // them to the calculator's add_mask_data.
  // Raises PreconditionNotMet when the prediction or the mask buffer does
  // not have the same length as the label buffer.
  void add_data(const Scope* exe_scope,
                const paddle::platform::Place& place) override {
    int label_len = 0;
    const int64_t* label_data = nullptr;
    get_data<int64_t>(exe_scope, label_varname_, &label_data, &label_len);
    int pred_len = 0;
    const float* pred_data = nullptr;
    get_data<float>(exe_scope, pred_varname_, &pred_data, &pred_len);
    int mask_len = 0;
    const int64_t* mask_data = nullptr;
    get_data<int64_t>(exe_scope, mask_varname_, &mask_data, &mask_len);

    // label_len is the element count handed to add_mask_data for all three
    // buffers, so both of the other lengths have to agree with it.
    PADDLE_ENFORCE_EQ(label_len, pred_len,
                      platform::errors::PreconditionNotMet(
                          "the predict data length should be consistent with "
                          "the label data length"));
    PADDLE_ENFORCE_EQ(label_len, mask_len,
                      platform::errors::PreconditionNotMet(
                          "the mask data length should be consistent with "
                          "the label data length"));

    auto cal = GetCalculator();
    cal->add_mask_data(pred_data, label_data, mask_data, label_len, place);
  }

 protected:
  std::string mask_varname_;  // name of the variable holding the mask
};
// AUC metric that combines the cmatch/rank group filter with an optional
// per-sample mask: a sample is accumulated only when its mask entry is
// non-zero (if a mask variable is configured) and its parsed cmatch/rank
// value matches one of the configured entries.
class CmatchRankMaskMetricMsg : public MetricMsg {
 public:
  // cmatch_rank_group is split with string::split_string. Each entry is
  // either a bare number (ignore_rank == true, stored as {number, 0}) or
  // "<first>_<second>" (ignore_rank == false). An empty mask_varname
  // disables masking.
  CmatchRankMaskMetricMsg(const std::string& label_varname,
                          const std::string& pred_varname, int metric_phase,
                          const std::string& cmatch_rank_group,
                          const std::string& cmatch_rank_varname,
                          bool ignore_rank = false,
                          const std::string& mask_varname = "",
                          int bucket_size = 1000000) {
    label_varname_ = label_varname;
    pred_varname_ = pred_varname;
    cmatch_rank_varname_ = cmatch_rank_varname;
    metric_phase_ = metric_phase;
    ignore_rank_ = ignore_rank;
    mask_varname_ = mask_varname;
    calculator = new BasicAucCalculator();
    calculator->init(bucket_size);
    for (auto& cmatch_rank : string::split_string(cmatch_rank_group)) {
      if (ignore_rank) {  // CmatchAUC
        cmatch_rank_v.emplace_back(atoi(cmatch_rank.c_str()), 0);
        continue;
      }
      const std::vector<std::string>& cur_cmatch_rank =
          string::split_string(cmatch_rank, "_");
      PADDLE_ENFORCE_EQ(
          cur_cmatch_rank.size(), 2,
          platform::errors::PreconditionNotMet(
              "illegal cmatch_rank auc spec: %s", cmatch_rank.c_str()));
      cmatch_rank_v.emplace_back(atoi(cur_cmatch_rank[0].c_str()),
                                 atoi(cur_cmatch_rank[1].c_str()));
    }
  }

  virtual ~CmatchRankMaskMetricMsg() {}

  // Reads cmatch_rank, label, prediction and (optionally) mask variables
  // from exe_scope, verifies that all of them have the same length and adds
  // the selected samples to the calculator under its table mutex.
  void add_data(const Scope* exe_scope,
                const paddle::platform::Place& place) override {
    std::vector<int64_t> cmatch_rank_data;
    get_data<int64_t>(exe_scope, cmatch_rank_varname_, &cmatch_rank_data);
    std::vector<int64_t> label_data;
    get_data<int64_t>(exe_scope, label_varname_, &label_data);
    std::vector<float> pred_data;
    get_data<float>(exe_scope, pred_varname_, &pred_data);
    size_t batch_size = cmatch_rank_data.size();
    PADDLE_ENFORCE_EQ(
        batch_size, label_data.size(),
        platform::errors::PreconditionNotMet(
            "illegal batch size: cmatch_rank[%lu] and label_data[%lu]",
            batch_size, label_data.size()));
    PADDLE_ENFORCE_EQ(
        batch_size, pred_data.size(),
        platform::errors::PreconditionNotMet(
            "illegal batch size: cmatch_rank[%lu] and pred_data[%lu]",
            batch_size, pred_data.size()));
    std::vector<int64_t> mask_data;
    if (!mask_varname_.empty()) {
      get_data<int64_t>(exe_scope, mask_varname_, &mask_data);
      PADDLE_ENFORCE_EQ(
          batch_size, mask_data.size(),
          platform::errors::PreconditionNotMet(
              "illegal batch size: cmatch_rank[%lu] and mask_data[%lu]",
              batch_size, mask_data.size()));
    }
    const bool use_mask = !mask_data.empty();
    auto cal = GetCalculator();
    std::lock_guard<std::mutex> lock(cal->table_mutex());
    for (size_t i = 0; i < batch_size; ++i) {
      // The mask depends only on the sample, not on the group entry, so it
      // is tested once per sample instead of once per group entry.
      if (use_mask && !mask_data[i]) {
        continue;
      }
      const auto& cur_cmatch_rank = parse_cmatch_rank(cmatch_rank_data[i]);
      for (size_t j = 0; j < cmatch_rank_v.size(); ++j) {
        bool is_matched = false;
        if (ignore_rank_) {
          is_matched = cmatch_rank_v[j].first == cur_cmatch_rank.first;
        } else {
          is_matched = cmatch_rank_v[j] == cur_cmatch_rank;
        }
        if (is_matched) {
          // Count the sample once even if several entries match.
          cal->add_unlock_data(pred_data[i], label_data[i]);
          break;
        }
      }
    }
  }

 protected:
  std::vector<std::pair<int, int>> cmatch_rank_v;  // accepted pairs
  std::string cmatch_rank_varname_;  // variable holding the raw cmatch_rank
  bool ignore_rank_;                 // compare only the first pair component
  std::string mask_varname_;         // empty means "no mask"
};
// Returns the process-wide Metric instance without creating it. The result
// is a null shared_ptr until SetInstance() has been called; the check that
// used to enforce this is disabled below, so callers must test for null.
static std::shared_ptr<Metric> GetInstance() {
  // PADDLE_ENFORCE_EQ(
  //     s_instance_ == nullptr, false,
  //     platform::errors::PreconditionNotMet(
  //         "GetInstance failed in Metric, you should use SetInstance
  //         firstly"));
  return s_instance_;
}
// Creates the process-wide Metric instance on first use and returns it.
// Later calls only log a warning and hand back the existing instance.
// Creation is serialised through a function-local mutex.
static std::shared_ptr<Metric> SetInstance() {
  static std::mutex init_mutex;
  std::lock_guard<std::mutex> guard(init_mutex);
  if (s_instance_ != nullptr) {
    LOG(WARNING) << "You have already used SetInstance() before";
    return s_instance_;
  }
  VLOG(3) << "s_instance_ is null";
  s_instance_.reset(new paddle::framework::Metric());
  return s_instance_;
}
const std::vector<std::string> GetMetricNameList(
int metric_phase = -1) const {
VLOG(0) << "Want to Get metric phase: " << metric_phase;
if (metric_phase == -1) {
return metric_name_list_;
} else {
std::vector<std::string> ret;
for (const auto& name : metric_name_list_) {
const auto iter = metric_lists_.find(name);
PADDLE_ENFORCE_NE(
iter, metric_lists_.end(),
platform::errors::InvalidArgument(
"The metric name you provided is not registered."));
if (iter->second->MetricPhase() == metric_phase) {
VLOG(3) << name << "'s phase is " << iter->second->MetricPhase()
<< ", we want";
ret.push_back(name);
} else {
VLOG(3) << name << "'s phase is " << iter->second->MetricPhase()
<< ", not we want";
}
}
return ret;
}
}
// Current phase index (phase_ starts at 1).
int Phase() const { return phase_; }
// Number of phases that FlipPhase() cycles through (phase_num_ defaults to 2).
int PhaseNum() const { return phase_num_; }
// Advances to the next phase, wrapping around at phase_num_.
void FlipPhase() { phase_ = (phase_ + 1) % phase_num_; }
// Mutable access to the name -> MetricMsg* registry.
std::map<std::string, MetricMsg*>& GetMetricList() { return metric_lists_; }
void InitMetric(const std::string& method, const std::string& name,
const std::string& label_varname,
const std::string& pred_varname,
const std::string& cmatch_rank_varname,
const std::string& mask_varname,
const std::string& uid_varname, int metric_phase,
const std::string& cmatch_rank_group, bool ignore_rank,
int bucket_size = 1000000) {
if (method == "AucCalculator") {
metric_lists_.emplace(name, new MetricMsg(label_varname, pred_varname,
metric_phase, bucket_size));
} else if (method == "MultiTaskAucCalculator") {
metric_lists_.emplace(
name, new MultiTaskMetricMsg(label_varname, pred_varname,
metric_phase, cmatch_rank_group,
cmatch_rank_varname, bucket_size));
} else if (method == "CmatchRankAucCalculator") {
metric_lists_.emplace(name, new CmatchRankMetricMsg(
label_varname, pred_varname, metric_phase,
cmatch_rank_group, cmatch_rank_varname,
ignore_rank, bucket_size));
} else if (method == "MaskAucCalculator") {
metric_lists_.emplace(
name, new MaskMetricMsg(label_varname, pred_varname, metric_phase,
mask_varname, bucket_size));
} else if (method == "CmatchRankMaskAucCalculator") {
metric_lists_.emplace(name, new CmatchRankMaskMetricMsg(
label_varname, pred_varname, metric_phase,
cmatch_rank_group, cmatch_rank_varname,
ignore_rank, mask_varname, bucket_size));
} else if (method == "WuAucCalculator") {
metric_lists_.emplace(
name, new WuAucMetricMsg(label_varname, pred_varname, uid_varname,
metric_phase, bucket_size));
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"PSLIB Metrics only support AucCalculator, MultiTaskAucCalculator, "
"CmatchRankAucCalculator, MaskAucCalculator, WuAucCalculator and "
"CmatchRankMaskAucCalculator"));
}
metric_name_list_.emplace_back(name);
}
// Computes the metric registered under `name`, returns its eight summary
// values and then resets the calculator.
// Layout of the result: [0] auc, [1] bucket_error, [2] mae, [3] rmse,
// [4] actual_ctr, [5] predicted_ctr, [6] actual_ctr / predicted_ctr,
// [7] size.
// Raises InvalidArgument when `name` is not registered.
const std::vector<float> GetMetricMsg(const std::string& name) {
  const auto found = metric_lists_.find(name);
  PADDLE_ENFORCE_NE(found, metric_lists_.end(),
                    platform::errors::InvalidArgument(
                        "The metric name you provided is not registered."));

  auto* cal = found->second->GetCalculator();
  std::vector<float> values(8, 0.0);
  cal->compute();
  values[0] = cal->auc();
  values[1] = cal->bucket_error();
  values[2] = cal->mae();
  values[3] = cal->rmse();
  values[4] = cal->actual_ctr();
  values[5] = cal->predicted_ctr();
  values[6] = cal->actual_ctr() / cal->predicted_ctr();
  values[7] = cal->size();
  // Start the next accumulation window from scratch.
  cal->reset();
  return values;
}
// Computes the WuAUC statistics of the metric registered under `name`,
// returns six values and then clears the calculator's records.
// Layout of the result: [0] user_cnt, [1] size, [2] uauc, [3] wuauc,
// [4] uauc / user_cnt, [5] wuauc / size (1e-10 is added to both
// denominators).
// When built with PADDLE_WITH_GLOO and the Gloo world size is above 1, the
// values are sum-all-reduced and the two ratios are recomputed from the
// reduced totals. Raises InvalidArgument when `name` is not registered.
const std::vector<float> GetWuAucMetricMsg(const std::string& name) {
  const auto found = metric_lists_.find(name);
  PADDLE_ENFORCE_NE(found, metric_lists_.end(),
                    platform::errors::InvalidArgument(
                        "The metric name you provided is not registered."));
  VLOG(0) << "begin GetWuAucMetricMsg";

  std::vector<float> local_values(6, 0.0);
  auto* wuauc_cal = found->second->GetCalculator();
  wuauc_cal->computeWuAuc();
  local_values[0] = wuauc_cal->user_cnt();
  local_values[1] = wuauc_cal->size();
  local_values[2] = wuauc_cal->uauc();
  local_values[3] = wuauc_cal->wuauc();
  local_values[4] = local_values[2] / (local_values[0] + 1e-10);
  local_values[5] = local_values[3] / (local_values[1] + 1e-10);

#if defined(PADDLE_WITH_GLOO)
  auto gloo = paddle::framework::GlooWrapper::GetInstance();
  if (gloo->Size() > 1) {
    auto global_values = gloo->AllReduce(local_values, "sum");
    // Summed ratios are meaningless; derive them again from the totals.
    global_values[4] = global_values[2] / (global_values[0] + 1e-10);
    global_values[5] = global_values[3] / (global_values[1] + 1e-10);
    wuauc_cal->reset_records();
    return global_values;
  }
#endif
  wuauc_cal->reset_records();
  return local_values;
}
private:
static std::shared_ptr<Metric> s_instance_;
// Metric Related
int phase_ = 1;
int phase_num_ = 2;
std::map<std::string, MetricMsg*> metric_lists_;
std::vector<std::string> metric_name_list_;
};
} // namespace framework
} // namespace paddle
......@@ -71,6 +71,10 @@ ELSE()
set(STREAM_CALLBACK_DEPS)
ENDIF()
if(WITH_GLOO)
cc_library(gloo_context SRCS gloo_context.cc DEPS framework_proto gloo_wrapper enforce)
endif()
cc_library(cudnn_workspace_helper SRCS cudnn_workspace_helper.cc DEPS boost)
# memcpy depends on device_context, here add deps individually for
......
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/platform/gloo_context.h"
namespace paddle {
namespace platform {
#if defined(PADDLE_WITH_GLOO)
// Copies the stored strategy into the GlooWrapper instance (rank, size,
// network interface, timeouts, HDFS store) and then initialises it.
void GlooParallelContext::Init() {
  auto gloo = paddle::framework::GlooWrapper::GetInstance();
  gloo->SetRank(strategy_.rank);
  gloo->SetSize(strategy_.rank_num);
  gloo->SetIface(strategy_.iface);
  gloo->SetTimeoutSeconds(strategy_.init_seconds, strategy_.run_seconds);
  // The HTTP store is currently disabled; the HDFS store is used instead.
  // gloo->SetHttpStore(strategy_.ip_address, strategy_.ip_port,
  //                    strategy_.scope);
  gloo->SetHdfsStore(strategy_.hdfs_path, strategy_.hdfs_name,
                     strategy_.hdfs_ugi);
  gloo->Init();
}
// Calls Barrier() on the GlooWrapper instance. Raises Unavailable when the
// wrapper has not been initialised yet.
void GlooParallelContext::Barrier() {
  auto gloo = paddle::framework::GlooWrapper::GetInstance();
  PADDLE_ENFORCE_EQ(gloo->IsInitialized(), true,
                    paddle::platform::errors::Unavailable(
                        "Gloo context is not initialized."));
  gloo->Barrier();
}
// Drops this function's handle to the GlooWrapper instance when the wrapper
// reports itself as initialised.
// NOTE(review): `auto` deduces a by-value object, so reset() only affects
// the local `gloo_ptr`. Whether anything held inside GlooWrapper is actually
// released depends on GetInstance() -- confirm this is the intended effect.
void GlooParallelContext::ReleaseContext() {
  auto gloo_ptr = paddle::framework::GlooWrapper::GetInstance();
  if (gloo_ptr->IsInitialized() == true) {
    gloo_ptr.reset();
  }
}
#endif
} // namespace platform
} // namespace paddle
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include "paddle/fluid/framework/fleet/gloo_wrapper.h"
namespace paddle {
namespace platform {
#if defined(PADDLE_WITH_GLOO)
// Plain configuration record consumed by GlooParallelContext.
// Every scalar member has a default so that a default-constructed strategy
// never exposes an indeterminate value (ip_port previously had none).
struct GlooParallelStrategy {
  int rank{0};                 // rank of this process
  int rank_num{1};             // total number of ranks
  std::string iface;           // network interface name
  int init_seconds{9999999};   // timeout passed as the "init" value
  int run_seconds{9999999};    // timeout passed as the "run" value
  std::string hdfs_path;       // HDFS store: path
  std::string hdfs_name;       // HDFS store: name
  std::string hdfs_ugi;        // HDFS store: ugi
  std::string ip_address;      // HTTP store: address
  int ip_port{0};              // HTTP store: port
  std::string scope{"worker"};  // HTTP store: scope
};
// Holds a copy of a GlooParallelStrategy and exposes the operations that
// act on the GlooWrapper instance. All operations are virtual so that
// subclasses can override them.
class GlooParallelContext {
 public:
  // Stores a copy of `strategy`; nothing is initialised until Init().
  explicit GlooParallelContext(const GlooParallelStrategy& strategy)
      : strategy_(strategy) {}
  virtual ~GlooParallelContext() {}
  // Applies the stored strategy to the Gloo wrapper and initialises it.
  virtual void Init();
  // Runs a barrier on the Gloo wrapper.
  virtual void Barrier();
  // Releases this context's handle to the Gloo wrapper.
  virtual void ReleaseContext();

 protected:
  GlooParallelStrategy strategy_;  // configuration captured at construction
};
#endif
} // namespace platform
} // namespace paddle
set(PYBIND_DEPS pybind python proto_desc memory executor fleet_wrapper box_wrapper prune
set(PYBIND_DEPS pybind python proto_desc memory executor fleet_wrapper box_wrapper metrics prune
feed_fetch_method pass_builder parallel_executor profiler layer tracer engine scope_pool
analysis_predictor imperative_profiler imperative_flag save_load_util dlpack_tensor device_context
gloo_wrapper infer_io_utils)
gloo_wrapper gloo_context infer_io_utils)
if (WITH_NCCL)
set(PYBIND_DEPS ${PYBIND_DEPS} nccl_wrapper)
......@@ -32,7 +32,9 @@ set(PYBIND_SRCS
reader_py.cc
fleet_wrapper_py.cc
gloo_wrapper_py.cc
gloo_context_py.cc
box_helper_py.cc
metrics_py.cc
data_set_py.cc
imperative.cc
ir.cc
......
......@@ -253,6 +253,8 @@ void BindDataset(py::module *m) {
py::call_guard<py::gil_scoped_release>())
.def("set_merge_by_sid", &framework::Dataset::SetMergeBySid,
py::call_guard<py::gil_scoped_release>())
.def("set_shuffle_by_uid", &framework::Dataset::SetShuffleByUid,
py::call_guard<py::gil_scoped_release>())
.def("preprocess_instance", &framework::Dataset::PreprocessInstance,
py::call_guard<py::gil_scoped_release>())
.def("postprocess_instance", &framework::Dataset::PostprocessInstance,
......
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/pybind/gloo_context_py.h"
#include <Python.h>
#include <pybind11/chrono.h>
#include <pybind11/complex.h>
#include <pybind11/functional.h>
#include <pybind11/stl.h>
#include <memory>
#include <set>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
#include "paddle/fluid/memory/allocation/mmap_allocator.h"
#include "paddle/fluid/platform/gloo_context.h"
namespace paddle {
namespace pybind {
namespace py = ::pybind11;
// Bind Methods
// Registers the Gloo parallel strategy/context classes on module `m`.
// Nothing is registered unless the build defines PADDLE_WITH_GLOO.
//
// GlooParallelStrategy: default-constructible, every bound field is a plain
// read/write attribute. def_readwrite generates the same getter/setter pair
// that was previously spelled out by hand for each field.
// GlooParallelContext: built from a strategy; exposes init(), barrier() and
// release().
void BindGlooContext(py::module *m) {
// define parallel context for gloo
#if defined(PADDLE_WITH_GLOO)
  using Strategy = platform::GlooParallelStrategy;
  using Context = platform::GlooParallelContext;

  py::class_<Strategy> gloo_parallel_strategy(*m, "GlooParallelStrategy", "");
  gloo_parallel_strategy.def(py::init())
      .def_readwrite("rank_num", &Strategy::rank_num)
      .def_readwrite("rank", &Strategy::rank)
      .def_readwrite("iface", &Strategy::iface)
      .def_readwrite("init_seconds", &Strategy::init_seconds)
      .def_readwrite("run_seconds", &Strategy::run_seconds)
      .def_readwrite("ip_address", &Strategy::ip_address)
      .def_readwrite("ip_port", &Strategy::ip_port)
      .def_readwrite("hdfs_path", &Strategy::hdfs_path)
      .def_readwrite("hdfs_name", &Strategy::hdfs_name)
      .def_readwrite("hdfs_ugi", &Strategy::hdfs_ugi);

  py::class_<Context> gloo_ctx(*m, "GlooParallelContext");
  gloo_ctx.def(py::init<const Strategy &>())
      .def("init", [](Context &self) { self.Init(); })
      .def("barrier", [](Context &self) { self.Barrier(); })
      .def("release", [](Context &self) { self.ReleaseContext(); });
#endif
}
} // namespace pybind
} // namespace paddle
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <Python.h>
#include "pybind11/pybind11.h"
#include "pybind11/stl.h"
namespace paddle {
namespace pybind {
void BindGlooContext(pybind11::module* m);
} // namespace pybind
} // namespace paddle
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <fcntl.h>
#ifdef _POSIX_C_SOURCE
#undef _POSIX_C_SOURCE
#endif
#ifdef _XOPEN_SOURCE
#undef _XOPEN_SOURCE
#endif
#include <memory>
#include <string>
#include <vector>
#include "google/protobuf/text_format.h"
#include "paddle/fluid/framework/fleet/metrics.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/inference/io.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/platform/variant.h"
#include "paddle/fluid/pybind/metrics_py.h"
namespace py = pybind11;
namespace paddle {
namespace pybind {
// Exposes framework::Metric to Python as `Metric`.
// Constructing `Metric()` from Python goes through Metric::SetInstance(),
// so every Python object refers to the shared instance. All methods release
// the GIL while the C++ code runs.
void BindMetrics(py::module* m) {
  py::class_<framework::Metric, std::shared_ptr<framework::Metric>>(*m,
                                                                    "Metric")
      .def(py::init([]() { return framework::Metric::SetInstance(); }))
      .def("init_metric", &framework::Metric::InitMetric,
           py::call_guard<py::gil_scoped_release>())
      .def("flip_phase", &framework::Metric::FlipPhase,
           py::call_guard<py::gil_scoped_release>())
      .def("get_metric_msg", &framework::Metric::GetMetricMsg,
           py::call_guard<py::gil_scoped_release>())
      .def("get_wuauc_metric_msg", &framework::Metric::GetWuAucMetricMsg,
           py::call_guard<py::gil_scoped_release>())
      // Mirror the C++ default (-1 = every phase) so the argument can be
      // omitted from Python as well; positional calls keep working.
      .def("get_metric_name_list", &framework::Metric::GetMetricNameList,
           py::arg("metric_phase") = -1,
           py::call_guard<py::gil_scoped_release>());
}  // end Metrics
} // end namespace pybind
} // end namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "pybind11/pybind11.h"
#include "pybind11/stl.h"
namespace py = pybind11;
namespace paddle {
namespace pybind {
void BindMetrics(py::module* m);
} // namespace pybind
} // namespace paddle
......@@ -59,11 +59,13 @@ limitations under the License. */
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/platform/profiler.h"
#include "paddle/fluid/pybind/box_helper_py.h"
#include "paddle/fluid/pybind/metrics_py.h"
#include "paddle/fluid/pybind/const_value.h"
#include "paddle/fluid/pybind/data_set_py.h"
#include "paddle/fluid/pybind/exception.h"
#include "paddle/fluid/pybind/fleet_wrapper_py.h"
#include "paddle/fluid/pybind/global_value_getter_setter.h"
#include "paddle/fluid/pybind/gloo_context_py.h"
#include "paddle/fluid/pybind/gloo_wrapper_py.h"
#include "paddle/fluid/pybind/imperative.h"
#include "paddle/fluid/pybind/inference_api.h"
......@@ -2414,11 +2416,15 @@ All parameter, weight, gradient are variables in Paddle.
BindFleetWrapper(&m);
BindGlooWrapper(&m);
BindBoxHelper(&m);
BindMetrics(&m);
#ifdef PADDLE_WITH_BOX_PS
BindBoxWrapper(&m);
#endif
#ifdef PADDLE_WITH_NCCL
BindNCCLWrapper(&m);
#endif
#ifdef PADDLE_WITH_GLOO
BindGlooContext(&m);
#endif
BindGraph(&m);
BindNode(&m);
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .metrics import init_metric # noqa: F401
from .metrics import print_auc # noqa: F401
\ No newline at end of file
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import yaml
import paddle.fluid as fluid
import logging
from paddle.distributed.utils import get_logger
__all__ = []
logger = get_logger(logging.INFO, name="metrics")
# read metric config from yaml and init MetricMsg in fleet_wrapper
def init_metric(metric_ptr,
                metric_yaml_path,
                cmatch_rank_var="",
                mask_var="",
                uid_var="",
                phase=-1,
                cmatch_rank_group="",
                ignore_rank=False,
                bucket_size=1000000):
    """
    Register every monitor listed in a yaml file on ``metric_ptr``.

    The yaml document must contain a ``monitors`` list. Each entry needs
    ``method``, ``name``, ``label``, ``target`` and ``phase``; depending on
    ``method`` it may also need ``cmatch_var``, ``cmatch_group``,
    ``ignore_rank``, ``mask`` or ``uid``. Values that an entry does not
    supply are taken from this function's arguments.

    Args:
        metric_ptr: object exposing ``init_metric`` with 11 positional
            arguments (method, name, label, target, cmatch_rank_var,
            mask_var, uid_var, phase, cmatch_rank_group, ignore_rank,
            bucket_size).
        metric_yaml_path(str): path of the yaml config file.
        cmatch_rank_var(str): fallback cmatch_rank variable name.
        mask_var(str): fallback mask variable name.
        uid_var(str): fallback uid variable name.
        phase(int): kept for backward compatibility; the phase is always
            derived from each monitor's ``phase`` entry
            (``'JOINING'`` -> 1, anything else -> 0).
        cmatch_rank_group(str): fallback cmatch_rank group.
        ignore_rank(bool): fallback ignore_rank flag.
        bucket_size(int): bucket size forwarded to every metric.
    """
    # NOTE: FullLoader is not the restricted safe loader; the config file is
    # expected to come from a trusted source.
    with open(metric_yaml_path) as yaml_fobj:
        if sys.version.startswith('2.7.13'):
            content = yaml.load(yaml_fobj)
        else:
            content = yaml.load(yaml_fobj, Loader=yaml.FullLoader)

    print("yaml metric config: \n")
    print(content)

    metric_runner_list = content['monitors']
    if not metric_runner_list:
        metric_runner_list = []

    for metric_runner in metric_runner_list:
        method = metric_runner['method']
        metric_phase = 1 if metric_runner['phase'] == 'JOINING' else 0

        # Start from the function-level fallbacks, then let the monitor
        # override the fields that its method actually uses.
        cur_cmatch_var = cmatch_rank_var
        cur_mask_var = mask_var
        cur_uid_var = uid_var
        cur_group = cmatch_rank_group
        cur_ignore_rank = ignore_rank

        if method == 'MultiTaskAucCalculator':
            cur_cmatch_var = metric_runner['cmatch_var']
            cur_group = metric_runner['cmatch_group']
        elif method == 'CmatchRankAucCalculator':
            cur_cmatch_var = metric_runner['cmatch_var']
            cur_group = metric_runner['cmatch_group']
            cur_ignore_rank = metric_runner['ignore_rank']
        elif method == 'MaskAucCalculator':
            cur_mask_var = metric_runner['mask']
        elif method == 'CmatchRankMaskAucCalculator':
            cur_cmatch_var = metric_runner['cmatch_var']
            cur_mask_var = metric_runner['mask']
            cur_group = metric_runner['cmatch_group']
            cur_ignore_rank = metric_runner['ignore_rank']
        elif method == 'WuAucCalculator':
            cur_uid_var = metric_runner['uid']
        # Any other method (including 'AucCalculator') uses the fallbacks
        # unchanged. All 11 arguments are always passed, so an unsupported
        # method is rejected by metric_ptr itself rather than by a
        # mismatched argument list.

        metric_ptr.init_metric(
            method, metric_runner['name'], metric_runner['label'],
            metric_runner['target'], cur_cmatch_var, cur_mask_var,
            cur_uid_var, metric_phase, cur_group, cur_ignore_rank,
            bucket_size)
def print_metric(metric_ptr, name):
    """
    Build the one-line report for the metric called ``name``.

    Names containing ``"wuauc"`` are read through ``get_wuauc_metric_msg``;
    every other name goes through ``get_metric_msg``. The formatted line is
    returned, not printed.
    """
    if "wuauc" in name:
        values = metric_ptr.get_wuauc_metric_msg(name)
        # Indices 4 and 5 hold the two ratio values shown as UAUC / WUAUC.
        fields = (name, values[0], values[1], values[4], values[5])
        template = "%s: User Count=%.0f INS Count=%.0f UAUC=%.6f WUAUC=%.6f "
    else:
        values = metric_ptr.get_metric_msg(name)
        fields = (name, ) + tuple(values[idx] for idx in range(8))
        template = ("%s: AUC=%.6f BUCKET_ERROR=%.6f MAE=%.6f RMSE=%.6f "
                    "Actual CTR=%.6f Predicted CTR=%.6f COPC=%.6f "
                    "INS Count=%.0f")
    return template % fields
def print_auc(metric_ptr, is_day, phase="all"):
    """
    Collect the report lines for one stage ("day" or "pass") and phase.

    ``is_day is True`` selects the "day" stage and asks ``metric_ptr`` for
    every metric name (-1); otherwise the "pass" stage is used and names are
    requested for phase 1 when ``phase == "join"`` and phase 0 in all other
    cases. A name is reported when it contains the stage and, unless
    ``phase == "all"``, also contains ``phase``.
    """
    if is_day is True:
        stage, stage_num = "day", -1
    else:
        stage = "pass"
        stage_num = 1 if phase == "join" else 0

    def _selected(metric_name):
        # Stage must always appear; phase only matters when it is not "all".
        if stage not in metric_name:
            return False
        return phase == "all" or phase in metric_name

    return [
        print_metric(metric_ptr, name=metric_name)
        for metric_name in metric_ptr.get_metric_name_list(stage_num)
        if _selected(metric_name)
    ]
......@@ -221,6 +221,23 @@ class DatasetBase(object):
self.dataset.set_filelist(filelist)
self.filelist = filelist
def set_uid_slot(self, uid_slot):
    """
    Set user slot name.

    The value is written to ``uid_slot`` of the dataset's multi-slot
    description.

    Examples:
        .. code-block:: python

          import paddle.fluid as fluid
          dataset = fluid.DatasetFactory().create_dataset()
          dataset.set_uid_slot('6048')

    Args:
        uid_slot(string): user slot name
    """
    multi_slot = self.proto_desc.multi_slot_desc
    multi_slot.uid_slot = uid_slot
def set_use_var(self, var_list):
"""
Set Variables which you will use.
......@@ -660,6 +677,23 @@ class InMemoryDataset(DatasetBase):
self.merge_by_lineid = True
self.parse_ins_id = True
def set_shuffle_by_uid(self, enable_shuffle_uid):
    """
    Set if Dataset need to shuffle by uid.

    The flag is forwarded unchanged to the underlying dataset object.

    Args:
        enable_shuffle_uid(bool): if shuffle according to uid or not

    Examples:
        .. code-block:: python

          import paddle.fluid as fluid
          dataset = fluid.DatasetFactory().create_dataset("InMemoryDataset")
          dataset.set_shuffle_by_uid(True)
    """
    self.dataset.set_shuffle_by_uid(enable_shuffle_uid)
def set_generate_unique_feasigns(self, generate_uni_feasigns, shard_num):
self.dataset.set_generate_unique_feasigns(generate_uni_feasigns)
self.gen_uni_feasigns = generate_uni_feasigns
......
......@@ -596,6 +596,7 @@ class GeneralRoleMaker(RoleMakerBase):
self._hdfs_path = kwargs.get("path", "").rstrip("/")
self._init_timeout_seconds = kwargs.get("init_timeout_seconds", 3600)
self._run_timeout_seconds = kwargs.get("run_timeout_seconds", 9999999)
self._use_metric = kwargs.get("use_metric", False)
ip_port = kwargs.get("http_ip_port", "")
self._http_ip_port = []
self._http_server = None
......@@ -609,6 +610,7 @@ class GeneralRoleMaker(RoleMakerBase):
# set running status of http server
self._http_server_d["running"] = False
self._iface = self.__get_default_iface()
self._iface = "" if self._iface == "lo" else self._iface
# this environment variable can be empty
self._prefix = os.getenv("SYS_JOB_ID", "")
......@@ -670,6 +672,21 @@ class GeneralRoleMaker(RoleMakerBase):
self._hdfs_name, self._hdfs_ugi)
gloo.init()
self._node_type_comm = gloo
if self._use_metric:
Gloo_strategy = fluid.core.GlooParallelStrategy()
Gloo_strategy.rank = current_id
Gloo_strategy.rank_num = len(worker_endpoints)
# Gloo_strategy.ip_address = self._http_ip_port[0]
# Gloo_strategy.ip_port = int(self._http_ip_port[1])
Gloo_strategy.hdfs_path = self._hdfs_path + "/trainer"
Gloo_strategy.hdfs_name = self._hdfs_name
Gloo_strategy.hdfs_ugi = self._hdfs_ugi
Default_init_timeout_seconds = 3600
Default_run_timeout_seconds = 9999999
Gloo_strategy.init_seconds = Default_init_timeout_seconds
Gloo_strategy.run_seconds = Default_run_timeout_seconds
Gloo = fluid.core.GlooParallelContext(Gloo_strategy)
Gloo.init()
else:
self._all_comm = MockBarrier()
elif training_role == "PSERVER":
......
......@@ -68,6 +68,14 @@ class TestDataset(unittest.TestCase):
self.assertTrue(dataset.parse_ins_id)
self.assertTrue(dataset.parse_content)
def test_shuffle_by_uid(self):
    """
    Testcase for shuffle_by_uid.

    Smoke test only: it checks that setting the uid slot and enabling
    shuffle-by-uid on an InMemoryDataset does not raise; no resulting state
    is asserted.
    """
    dataset = paddle.distributed.InMemoryDataset()
    dataset._set_uid_slot('6048')
    dataset._set_shuffle_by_uid(True)
def test_run_with_dump(self):
"""
Testcase for InMemoryDataset from create to run.
......
......@@ -140,6 +140,7 @@ packages=['paddle',
'paddle.dataset',
'paddle.reader',
'paddle.distributed',
'paddle.distributed.metric',
'paddle.complex',
'paddle.complex.tensor',
'paddle.fluid',
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册