Unverified · Commit 2e6be886 · Authored by Fan Zhang · Committed by GitHub

[PSLIB] Add Metrics Module, Support User-defined Add Metric (#38789)

* [PSLIB] Add Metrics Module, Support User-defined Add Metric

* [PSLIB] Modify According to CI

* [PSLIB] Modify According to CI

* [PSLIB] Modify According to CI

* [PSLIB] Modify According to CI Coverage

* [PSLIB] Modify According to CI

* [PSLIB] Modify According to CI

* [PSLIB] Modify According to CI

* [PSLIB] Modify According to CI

* [PSLIB] Modify According to CI

* [PSLIB] Modify According to CI Coverage

* [PSLIB] Modify According to CI Coverage

* [PSLIB] Modify According to CI Coverage

* modify role_maker

* update CMakeLists.txt
Parent 3ef2922b
@@ -293,7 +293,7 @@ if(WITH_DISTRIBUTE)
ps_gpu_trainer.cc downpour_worker.cc downpour_worker_opt.cc
pull_dense_worker.cc section_worker.cc device_worker_factory.cc data_set.cc DEPS op_registry
device_context scope framework_proto trainer_desc_proto glog fs shell
- fleet_wrapper heter_wrapper ps_gpu_wrapper box_wrapper lodtensor_printer
+ fleet_wrapper heter_wrapper ps_gpu_wrapper box_wrapper metrics lodtensor_printer
lod_rank_table feed_fetch_method collective_helper ${GLOB_DISTRIBUTE_DEPS}
graph_to_program_pass variable_helper data_feed_proto timer monitor
heter_service_proto fleet_executor ${BRPC_DEP})
@@ -315,7 +315,7 @@ if(WITH_DISTRIBUTE)
pull_dense_worker.cc section_worker.cc heter_section_worker.cc device_worker_factory.cc data_set.cc DEPS op_registry
device_context scope framework_proto data_feed_proto heter_service_proto trainer_desc_proto glog
index_sampler index_wrapper sampler index_dataset_proto
- lod_rank_table fs shell fleet_wrapper heter_wrapper box_wrapper lodtensor_printer feed_fetch_method
+ lod_rank_table fs shell fleet_wrapper heter_wrapper box_wrapper metrics lodtensor_printer feed_fetch_method
graph_to_program_pass variable_helper timer monitor heter_service_proto fleet heter_server brpc fleet_executor)
set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor")
if (CMAKE_CXX_COMPILER_VERSION VERSION_GREATER 7.0)
@@ -336,7 +336,7 @@ if(WITH_DISTRIBUTE)
ps_gpu_trainer.cc downpour_worker.cc downpour_worker_opt.cc
pull_dense_worker.cc section_worker.cc device_worker_factory.cc data_set.cc DEPS op_registry
device_context scope framework_proto data_feed_proto heter_service_proto trainer_desc_proto glog
- lod_rank_table fs shell fleet_wrapper heter_wrapper ps_gpu_wrapper box_wrapper lodtensor_printer feed_fetch_method
+ lod_rank_table fs shell fleet_wrapper heter_wrapper ps_gpu_wrapper box_wrapper metrics lodtensor_printer feed_fetch_method
graph_to_program_pass variable_helper timer monitor fleet_executor)
endif()
elseif(WITH_PSLIB)
......
@@ -340,6 +340,7 @@ InMemoryDataFeed<T>::InMemoryDataFeed() {
this->thread_id_ = 0;
this->thread_num_ = 1;
this->parse_ins_id_ = false;
this->parse_uid_ = false;
this->parse_content_ = false;
this->parse_logkey_ = false;
this->enable_pv_merge_ = false;
@@ -498,6 +499,11 @@ void InMemoryDataFeed<T>::SetParseInsId(bool parse_ins_id) {
parse_ins_id_ = parse_ins_id;
}
template <typename T>
void InMemoryDataFeed<T>::SetParseUid(bool parse_uid) {
parse_uid_ = parse_uid;
}
template <typename T>
void InMemoryDataFeed<T>::LoadIntoMemory() {
#ifdef _LINUX
@@ -1047,6 +1053,7 @@ void MultiSlotInMemoryDataFeed::Init(
use_slots_shape_.push_back(local_shape);
}
}
uid_slot_ = multi_slot_desc.uid_slot();
feed_vec_.resize(use_slots_.size());
const int kEstimatedFeasignNumPerSlot = 5;  // Magic Number
for (size_t i = 0; i < all_slot_num; i++) {
@@ -1160,6 +1167,19 @@ bool MultiSlotInMemoryDataFeed::ParseOneInstanceFromPipe(Record* instance) {
"\nWe detect the feasign number of this slot is %d, "
"which is illegal.",
str, i, num));
#ifdef PADDLE_WITH_PSLIB
if (parse_uid_ && all_slots_[i] == uid_slot_) {
PADDLE_ENFORCE(num == 1 && all_slots_type_[i][0] == 'u',
platform::errors::PreconditionNotMet(
"The uid has to be uint64 and single.\n"
"please check this error line: %s",
str));
char* uidptr = endptr;
uint64_t feasign = (uint64_t)strtoull(uidptr, &uidptr, 10);
instance->uid_ = feasign;
}
#endif
if (idx != -1) {
if (all_slots_type_[i][0] == 'f') {  // float
for (int j = 0; j < num; ++j) {
......
@@ -191,6 +191,7 @@ struct Record {
uint64_t search_id;
uint32_t rank;
uint32_t cmatch;
std::string uid_;
};
inline SlotRecord make_slotrecord() {
@@ -562,6 +563,7 @@ class DataFeed {
virtual void SetThreadNum(int thread_num) {}
// This function will do nothing at default
virtual void SetParseInsId(bool parse_ins_id) {}
virtual void SetParseUid(bool parse_uid) {}
virtual void SetParseContent(bool parse_content) {}
virtual void SetParseLogKey(bool parse_logkey) {}
virtual void SetEnablePvMerge(bool enable_pv_merge) {}
@@ -645,6 +647,7 @@ class DataFeed {
std::vector<std::string> ins_id_vec_;
std::vector<std::string> ins_content_vec_;
platform::Place place_;
std::string uid_slot_;
// The input type of pipe reader, 0 for one sample, 1 for one batch
int input_type_;
@@ -709,6 +712,7 @@ class InMemoryDataFeed : public DataFeed {
virtual void SetThreadId(int thread_id);
virtual void SetThreadNum(int thread_num);
virtual void SetParseInsId(bool parse_ins_id);
virtual void SetParseUid(bool parse_uid);
virtual void SetParseContent(bool parse_content);
virtual void SetParseLogKey(bool parse_logkey);
virtual void SetEnablePvMerge(bool enable_pv_merge);
@@ -737,6 +741,7 @@ class InMemoryDataFeed : public DataFeed {
int thread_id_;
int thread_num_;
bool parse_ins_id_;
bool parse_uid_;
bool parse_content_;
bool parse_logkey_;
bool enable_pv_merge_;
......
@@ -22,7 +22,10 @@ message Slot {
repeated int32 shape = 5;  // we can define N-D Tensor
}
- message MultiSlotDesc { repeated Slot slots = 1; }
+ message MultiSlotDesc {
+   repeated Slot slots = 1;
+   optional string uid_slot = 2;
+ }
message DataFeedDesc {
optional string name = 1;
......
@@ -57,6 +57,8 @@ DatasetImpl<T>::DatasetImpl() {
parse_logkey_ = false;
preload_thread_num_ = 0;
global_index_ = 0;
shuffle_by_uid_ = false;
parse_uid_ = false;
} }
// set filelist, file_idx_ will reset to zero.
@@ -150,6 +152,12 @@ void DatasetImpl<T>::SetMergeBySid(bool is_merge) {
merge_by_sid_ = is_merge;
}
template <typename T>
void DatasetImpl<T>::SetShuffleByUid(bool enable_shuffle_uid) {
shuffle_by_uid_ = enable_shuffle_uid;
parse_uid_ = true;
}
template <typename T>
void DatasetImpl<T>::SetEnablePvMerge(bool enable_pv_merge) {
enable_pv_merge_ = enable_pv_merge;
@@ -664,11 +672,14 @@ void MultiSlotDataset::GlobalShuffle(int thread_num) {
<< input_channel_->Size();
auto get_client_id = [this, fleet_ptr](const Record& data) -> size_t {
- if (!this->merge_by_insid_) {
-   return fleet_ptr->LocalRandomEngine()() % this->trainer_num_;
- } else {
+ if (this->merge_by_insid_) {
return XXH64(data.ins_id_.data(), data.ins_id_.length(), 0) %
this->trainer_num_;
+ } else if (this->shuffle_by_uid_) {
+   return XXH64(data.uid_.data(), data.uid_.length(), 0) %
+          this->trainer_num_;
+ } else {
+   return fleet_ptr->LocalRandomEngine()() % this->trainer_num_;
}
};
@@ -902,6 +913,7 @@ void DatasetImpl<T>::CreateReaders() {
readers_[i]->SetFeaNum(&total_fea_num_);
readers_[i]->SetFileList(filelist_);
readers_[i]->SetParseInsId(parse_ins_id_);
readers_[i]->SetParseUid(parse_uid_);
readers_[i]->SetParseContent(parse_content_);
readers_[i]->SetParseLogKey(parse_logkey_);
readers_[i]->SetEnablePvMerge(enable_pv_merge_);
@@ -972,6 +984,7 @@ void DatasetImpl<T>::CreatePreLoadReaders() {
preload_readers_[i]->SetFeaNumMutex(&mutex_for_fea_num_);
preload_readers_[i]->SetFeaNum(&total_fea_num_);
preload_readers_[i]->SetParseInsId(parse_ins_id_);
preload_readers_[i]->SetParseUid(parse_uid_);
preload_readers_[i]->SetParseContent(parse_content_);
preload_readers_[i]->SetParseLogKey(parse_logkey_);
preload_readers_[i]->SetEnablePvMerge(enable_pv_merge_);
......
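Note on the GlobalShuffle change above: get_client_id now routes records by XXH64(ins_id) when merging by instance id, by XXH64(uid_) when shuffle-by-uid is enabled, and randomly otherwise. Hashing the uid guarantees that every record of the same user lands on the same trainer, which the per-user (WuAUC) metrics later in this commit rely on. A minimal Python sketch of that routing idea, using Python's built-in hash as a stand-in for XXH64 (illustrative only, not Paddle code):

# Route every record of a given uid to one trainer so per-user metrics
# can be computed locally; hash() stands in for XXH64 here.
def get_client_id(uid, trainer_num):
    return hash(uid) % trainer_num

records = [("user_a", 1), ("user_b", 0), ("user_a", 0), ("user_c", 1)]
trainer_num = 4
buckets = {}
for uid, label in records:
    buckets.setdefault(get_client_id(uid, trainer_num), []).append((uid, label))
print(buckets)  # all "user_a" records share one bucket/trainer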
@@ -81,6 +81,7 @@ class Dataset {
virtual void SetEnablePvMerge(bool enable_pv_merge) = 0;
virtual bool EnablePvMerge() = 0;
virtual void SetMergeBySid(bool is_merge) = 0;
virtual void SetShuffleByUid(bool enable_shuffle_uid) = 0;
// set merge by ins id
virtual void SetMergeByInsId(int merge_size) = 0;
virtual void SetGenerateUniqueFeasign(bool gen_uni_feasigns) = 0;
@@ -189,6 +190,7 @@ class DatasetImpl : public Dataset {
virtual void SetParseLogKey(bool parse_logkey);
virtual void SetEnablePvMerge(bool enable_pv_merge);
virtual void SetMergeBySid(bool is_merge);
virtual void SetShuffleByUid(bool enable_shuffle_uid);
virtual void SetMergeByInsId(int merge_size);
virtual void SetGenerateUniqueFeasign(bool gen_uni_feasigns);
@@ -307,6 +309,8 @@ class DatasetImpl : public Dataset {
bool parse_content_;
bool parse_logkey_;
bool merge_by_sid_;
bool shuffle_by_uid_;
bool parse_uid_;
bool enable_pv_merge_;  // True means to merge pv
int current_phase_;     // 1 join, 0 update
size_t merge_size_;
......
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/device_worker.h"
#include "paddle/fluid/framework/fleet/metrics.h"
#include "paddle/fluid/platform/cpu_helper.h" #include "paddle/fluid/platform/cpu_helper.h"
namespace pten { namespace pten {
@@ -32,7 +33,6 @@ class Variable;
namespace paddle {
namespace framework {
void DownpourWorker::Initialize(const TrainerDesc& desc) {
param_ = desc.downpour_param();
for (int i = 0; i < param_.sparse_table_size(); ++i) {
@@ -740,6 +740,23 @@ void DownpourWorker::TrainFilesWithProfiler() {
}
}
#ifdef PADDLE_WITH_PSLIB
/**
* @brief add auc monitor
*/
inline void AddAucMonitor(const Scope* scope, const platform::Place& place) {
auto metric_ptr = Metric::GetInstance();
auto& metric_list = metric_ptr->GetMetricList();
for (auto iter = metric_list.begin(); iter != metric_list.end(); iter++) {
auto* metric_msg = iter->second;
if (metric_ptr->Phase() != metric_msg->MetricPhase()) {
continue;
}
metric_msg->add_data(scope, place);
}
}
#endif
void DownpourWorker::TrainFiles() {
VLOG(3) << "Begin to train files";
platform::SetNumThreads(1);
@@ -837,6 +854,13 @@ void DownpourWorker::TrainFiles() {
}
}
#ifdef PADDLE_WITH_PSLIB
// add data for MetricMsg
if (Metric::GetInstance() != nullptr) {
AddAucMonitor(thread_scope_, place_);
}
#endif
// check inf and nan
for (std::string& var_name : check_nan_var_names_) {
Variable* var = thread_scope_->FindVar(var_name);
......
@@ -42,8 +42,10 @@ endif(WITH_BOX_PS)
if(WITH_GLOO)
cc_library(gloo_wrapper SRCS gloo_wrapper.cc DEPS framework_proto variable_helper scope gloo)
cc_library(metrics SRCS metrics.cc DEPS gloo_wrapper)
else()
cc_library(gloo_wrapper SRCS gloo_wrapper.cc DEPS framework_proto variable_helper scope)
cc_library(metrics SRCS metrics.cc DEPS gloo_wrapper)
endif(WITH_GLOO)
if(WITH_PSLIB)
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/fleet/metrics.h"
#include <algorithm>
#include <ctime>
#include <memory>
#include <numeric>
#include "paddle/fluid/framework/lod_tensor.h"
#if defined(PADDLE_WITH_PSLIB)
namespace paddle {
namespace framework {
std::shared_ptr<Metric> Metric::s_instance_ = nullptr;
void BasicAucCalculator::init(int table_size) {
set_table_size(table_size);
// init CPU memory
for (int i = 0; i < 2; i++) {
_table[i] = std::vector<double>();
}
// reset
reset();
}
void BasicAucCalculator::reset() {
// reset CPU counter
for (int i = 0; i < 2; i++) {
_table[i].assign(_table_size, 0.0);
}
_local_abserr = 0;
_local_sqrerr = 0;
_local_pred = 0;
}
void BasicAucCalculator::add_data(const float* d_pred, const int64_t* d_label,
int batch_size,
const paddle::platform::Place& place) {
thread_local std::vector<float> h_pred;
thread_local std::vector<int64_t> h_label;
h_pred.resize(batch_size);
h_label.resize(batch_size);
memcpy(h_pred.data(), d_pred, sizeof(float) * batch_size);
memcpy(h_label.data(), d_label, sizeof(int64_t) * batch_size);
std::lock_guard<std::mutex> lock(_table_mutex);
for (int i = 0; i < batch_size; ++i) {
add_unlock_data(h_pred[i], h_label[i]);
}
}
void BasicAucCalculator::add_unlock_data(double pred, int label) {
PADDLE_ENFORCE_GE(pred, 0.0, platform::errors::PreconditionNotMet(
"pred should be greater than 0"));
PADDLE_ENFORCE_LE(pred, 1.0, platform::errors::PreconditionNotMet(
"pred should be lower than 1"));
PADDLE_ENFORCE_EQ(
label * label, label,
platform::errors::PreconditionNotMet(
"label must be equal to 0 or 1, but its value is: %d", label));
int pos = std::min(static_cast<int>(pred * _table_size), _table_size - 1);
PADDLE_ENFORCE_GE(
pos, 0,
platform::errors::PreconditionNotMet(
"pos must be equal or greater than 0, but its value is: %d", pos));
PADDLE_ENFORCE_LT(
pos, _table_size,
platform::errors::PreconditionNotMet(
"pos must be less than table_size, but its value is: %d", pos));
_local_abserr += fabs(pred - label);
_local_sqrerr += (pred - label) * (pred - label);
_local_pred += pred;
++_table[label][pos];
}
// add mask data
void BasicAucCalculator::add_mask_data(const float* d_pred,
const int64_t* d_label,
const int64_t* d_mask, int batch_size,
const paddle::platform::Place& place) {
thread_local std::vector<float> h_pred;
thread_local std::vector<int64_t> h_label;
thread_local std::vector<int64_t> h_mask;
h_pred.resize(batch_size);
h_label.resize(batch_size);
h_mask.resize(batch_size);
memcpy(h_pred.data(), d_pred, sizeof(float) * batch_size);
memcpy(h_label.data(), d_label, sizeof(int64_t) * batch_size);
memcpy(h_mask.data(), d_mask, sizeof(int64_t) * batch_size);
std::lock_guard<std::mutex> lock(_table_mutex);
for (int i = 0; i < batch_size; ++i) {
if (h_mask[i]) {
add_unlock_data(h_pred[i], h_label[i]);
}
}
}
void BasicAucCalculator::compute() {
#if defined(PADDLE_WITH_GLOO)
double area = 0;
double fp = 0;
double tp = 0;
auto gloo_wrapper = paddle::framework::GlooWrapper::GetInstance();
if (!gloo_wrapper->IsInitialized()) {
VLOG(0) << "GLOO is not inited";
gloo_wrapper->Init();
}
if (gloo_wrapper->Size() > 1) {
auto neg_table = gloo_wrapper->AllReduce(_table[0], "sum");
auto pos_table = gloo_wrapper->AllReduce(_table[1], "sum");
for (int i = _table_size - 1; i >= 0; i--) {
double newfp = fp + neg_table[i];
double newtp = tp + pos_table[i];
area += (newfp - fp) * (tp + newtp) / 2;
fp = newfp;
tp = newtp;
}
} else {
for (int i = _table_size - 1; i >= 0; i--) {
double newfp = fp + _table[0][i];
double newtp = tp + _table[1][i];
area += (newfp - fp) * (tp + newtp) / 2;
fp = newfp;
tp = newtp;
}
}
if (fp < 1e-3 || tp < 1e-3) {
_auc = -0.5; // which means all nonclick or click
} else {
_auc = area / (fp * tp);
}
if (gloo_wrapper->Size() > 1) {
// allreduce sum
std::vector<double> local_abserr_vec(1, _local_abserr);
std::vector<double> local_sqrerr_vec(1, _local_sqrerr);
std::vector<double> local_pred_vec(1, _local_pred);
auto global_abserr_vec = gloo_wrapper->AllReduce(local_abserr_vec, "sum");
auto global_sqrerr_vec = gloo_wrapper->AllReduce(local_sqrerr_vec, "sum");
auto global_pred_vec = gloo_wrapper->AllReduce(local_pred_vec, "sum");
_mae = global_abserr_vec[0] / (fp + tp);
_rmse = sqrt(global_sqrerr_vec[0] / (fp + tp));
_predicted_ctr = global_pred_vec[0] / (fp + tp);
} else {
_mae = _local_abserr / (fp + tp);
_rmse = sqrt(_local_sqrerr / (fp + tp));
_predicted_ctr = _local_pred / (fp + tp);
}
_actual_ctr = tp / (fp + tp);
_size = fp + tp;
calculate_bucket_error();
#endif
}
void BasicAucCalculator::calculate_bucket_error() {
#if defined(PADDLE_WITH_GLOO)
double last_ctr = -1;
double impression_sum = 0;
double ctr_sum = 0.0;
double click_sum = 0.0;
double error_sum = 0.0;
double error_count = 0;
auto gloo_wrapper = paddle::framework::GlooWrapper::GetInstance();
if (gloo_wrapper->Size() > 1) {
auto neg_table = gloo_wrapper->AllReduce(_table[0], "sum");
auto pos_table = gloo_wrapper->AllReduce(_table[1], "sum");
for (int i = 0; i < _table_size; i++) {
double click = pos_table[i];
double show = neg_table[i] + pos_table[i];
double ctr = static_cast<double>(i) / _table_size;
if (fabs(ctr - last_ctr) > kMaxSpan) {
last_ctr = ctr;
impression_sum = 0.0;
ctr_sum = 0.0;
click_sum = 0.0;
}
impression_sum += show;
ctr_sum += ctr * show;
click_sum += click;
double adjust_ctr = ctr_sum / impression_sum;
double relative_error =
sqrt((1 - adjust_ctr) / (adjust_ctr * impression_sum));
if (relative_error < kRelativeErrorBound) {
double actual_ctr = click_sum / impression_sum;
double relative_ctr_error = fabs(actual_ctr / adjust_ctr - 1);
error_sum += relative_ctr_error * impression_sum;
error_count += impression_sum;
last_ctr = -1;
}
}
} else {
double* table[2] = {&_table[0][0], &_table[1][0]};
for (int i = 0; i < _table_size; i++) {
double click = table[1][i];
double show = table[0][i] + table[1][i];
double ctr = static_cast<double>(i) / _table_size;
if (fabs(ctr - last_ctr) > kMaxSpan) {
last_ctr = ctr;
impression_sum = 0.0;
ctr_sum = 0.0;
click_sum = 0.0;
}
impression_sum += show;
ctr_sum += ctr * show;
click_sum += click;
double adjust_ctr = ctr_sum / impression_sum;
double relative_error =
sqrt((1 - adjust_ctr) / (adjust_ctr * impression_sum));
if (relative_error < kRelativeErrorBound) {
double actual_ctr = click_sum / impression_sum;
double relative_ctr_error = fabs(actual_ctr / adjust_ctr - 1);
error_sum += relative_ctr_error * impression_sum;
error_count += impression_sum;
last_ctr = -1;
}
}
}
_bucket_error = error_count > 0 ? error_sum / error_count : 0.0;
#endif
}
void BasicAucCalculator::reset_records() {
// reset wuauc_records_
wuauc_records_.clear();
_user_cnt = 0;
_size = 0;
_uauc = 0;
_wuauc = 0;
}
// add uid data
void BasicAucCalculator::add_uid_data(const float* d_pred,
const int64_t* d_label,
const int64_t* d_uid, int batch_size,
const paddle::platform::Place& place) {
thread_local std::vector<float> h_pred;
thread_local std::vector<int64_t> h_label;
thread_local std::vector<uint64_t> h_uid;
h_pred.resize(batch_size);
h_label.resize(batch_size);
h_uid.resize(batch_size);
memcpy(h_pred.data(), d_pred, sizeof(float) * batch_size);
memcpy(h_label.data(), d_label, sizeof(int64_t) * batch_size);
memcpy(h_uid.data(), d_uid, sizeof(uint64_t) * batch_size);
std::lock_guard<std::mutex> lock(_table_mutex);
for (int i = 0; i < batch_size; ++i) {
add_uid_unlock_data(h_pred[i], h_label[i], static_cast<uint64_t>(h_uid[i]));
}
}
void BasicAucCalculator::add_uid_unlock_data(double pred, int label,
uint64_t uid) {
PADDLE_ENFORCE_GE(pred, 0.0, platform::errors::PreconditionNotMet(
"pred should be greater than 0"));
PADDLE_ENFORCE_LE(pred, 1.0, platform::errors::PreconditionNotMet(
"pred should be lower than 1"));
PADDLE_ENFORCE_EQ(
label * label, label,
platform::errors::PreconditionNotMet(
"label must be equal to 0 or 1, but its value is: %d", label));
WuaucRecord record;
record.uid_ = uid;
record.label_ = label;
record.pred_ = pred;
wuauc_records_.emplace_back(std::move(record));
}
void BasicAucCalculator::computeWuAuc() {
std::sort(wuauc_records_.begin(), wuauc_records_.end(),
[](const WuaucRecord& lhs, const WuaucRecord& rhs) {
if (lhs.uid_ == rhs.uid_) {
if (lhs.pred_ == rhs.pred_) {
return lhs.label_ < rhs.label_;
} else {
return lhs.pred_ > rhs.pred_;
}
} else {
return lhs.uid_ > rhs.uid_;
}
});
WuaucRocData roc_data;
uint64_t prev_uid = 0;
size_t prev_pos = 0;
for (size_t i = 0; i < wuauc_records_.size(); ++i) {
if (wuauc_records_[i].uid_ != prev_uid) {
std::vector<WuaucRecord> single_user_recs(
wuauc_records_.begin() + prev_pos, wuauc_records_.begin() + i);
roc_data = computeSingelUserAuc(single_user_recs);
if (roc_data.auc_ != -1) {
double ins_num = (roc_data.tp_ + roc_data.fp_);
_user_cnt += 1;
_size += ins_num;
_uauc += roc_data.auc_;
_wuauc += roc_data.auc_ * ins_num;
}
prev_uid = wuauc_records_[i].uid_;
prev_pos = i;
}
}
std::vector<WuaucRecord> single_user_recs(wuauc_records_.begin() + prev_pos,
wuauc_records_.end());
roc_data = computeSingelUserAuc(single_user_recs);
if (roc_data.auc_ != -1) {
double ins_num = (roc_data.tp_ + roc_data.fp_);
_user_cnt += 1;
_size += ins_num;
_uauc += roc_data.auc_;
_wuauc += roc_data.auc_ * ins_num;
}
}
BasicAucCalculator::WuaucRocData BasicAucCalculator::computeSingelUserAuc(
const std::vector<WuaucRecord>& records) {
double tp = 0.0;
double fp = 0.0;
double newtp = 0.0;
double newfp = 0.0;
double area = 0.0;
double auc = -1;
size_t i = 0;
while (i < records.size()) {
newtp = tp;
newfp = fp;
if (records[i].label_ == 1) {
newtp += 1;
} else {
newfp += 1;
}
// check i+1
while (i < records.size() - 1 && records[i].pred_ == records[i + 1].pred_) {
if (records[i + 1].label_ == 1) {
newtp += 1;
} else {
newfp += 1;
}
i += 1;
}
area += (newfp - fp) * (tp + newtp) / 2.0;
tp = newtp;
fp = newfp;
i += 1;
}
if (tp > 0 && fp > 0) {
auc = area / (fp * tp + 1e-9);
} else {
auc = -1;
}
return {tp, fp, auc};
}
} // namespace framework
} // namespace paddle
#endif
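For readers unfamiliar with the bucketed AUC above: BasicAucCalculator::compute() keeps two histograms indexed by the discretized prediction (negatives in _table[0], positives in _table[1]) and sweeps them from the highest bucket to the lowest, accumulating trapezoids under the ROC curve. A small self-contained Python sketch of the same accumulation (illustrative only, not Paddle code):

# neg_table[i] / pos_table[i]: number of negative / positive examples whose
# prediction falls into bucket i (larger i means a higher predicted score).
def bucketed_auc(neg_table, pos_table):
    area, fp, tp = 0.0, 0.0, 0.0
    for i in reversed(range(len(neg_table))):  # sweep high score -> low score
        new_fp = fp + neg_table[i]
        new_tp = tp + pos_table[i]
        area += (new_fp - fp) * (tp + new_tp) / 2.0  # trapezoid under ROC
        fp, tp = new_fp, new_tp
    if fp < 1e-3 or tp < 1e-3:
        return -0.5  # all clicks or all non-clicks, mirroring the C++ code
    return area / (fp * tp)

print(bucketed_auc([5, 3, 1, 0], [0, 1, 3, 5]))  # ~0.95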
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <ThreadPool.h>
#include <atomic>
#include <ctime>
#include <map>
#include <memory>
#include <random>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/framework/variable_helper.h"
#include "paddle/fluid/platform/timer.h"
#include "paddle/fluid/string/string_helper.h"
#if defined(PADDLE_WITH_GLOO)
#include <gloo/allreduce.h>
#include "paddle/fluid/framework/fleet/gloo_wrapper.h"
#endif
#if defined(PADDLE_WITH_PSLIB)
namespace paddle {
namespace framework {
class BasicAucCalculator {
public:
BasicAucCalculator() {}
struct WuaucRecord {
uint64_t uid_;
int label_;
float pred_;
};
struct WuaucRocData {
double tp_;
double fp_;
double auc_;
};
void init(int table_size);
void init_wuauc(int table_size);
void reset();
void reset_records();
// add single data in CPU with LOCK, deprecated
void add_unlock_data(double pred, int label);
void add_uid_unlock_data(double pred, int label, uint64_t uid);
// add batch data
void add_data(const float* d_pred, const int64_t* d_label, int batch_size,
const paddle::platform::Place& place);
// add mask data
void add_mask_data(const float* d_pred, const int64_t* d_label,
const int64_t* d_mask, int batch_size,
const paddle::platform::Place& place);
// add uid data
void add_uid_data(const float* d_pred, const int64_t* d_label,
const int64_t* d_uid, int batch_size,
const paddle::platform::Place& place);
void compute();
void computeWuAuc();
WuaucRocData computeSingelUserAuc(const std::vector<WuaucRecord>& records);
int table_size() const { return _table_size; }
double bucket_error() const { return _bucket_error; }
double auc() const { return _auc; }
double uauc() const { return _uauc; }
double wuauc() const { return _wuauc; }
double mae() const { return _mae; }
double actual_ctr() const { return _actual_ctr; }
double predicted_ctr() const { return _predicted_ctr; }
double user_cnt() const { return _user_cnt; }
double size() const { return _size; }
double rmse() const { return _rmse; }
std::unordered_set<uint64_t> uid_keys() const { return _uid_keys; }
// lock and unlock
std::mutex& table_mutex(void) { return _table_mutex; }
private:
void calculate_bucket_error();
protected:
double _local_abserr = 0;
double _local_sqrerr = 0;
double _local_pred = 0;
double _auc = 0;
double _uauc = 0;
double _wuauc = 0;
double _mae = 0;
double _rmse = 0;
double _actual_ctr = 0;
double _predicted_ctr = 0;
double _size;
double _user_cnt = 0;
double _bucket_error = 0;
std::unordered_set<uint64_t> _uid_keys;
private:
void set_table_size(int table_size) { _table_size = table_size; }
int _table_size;
std::vector<double> _table[2];
std::vector<WuaucRecord> wuauc_records_;
static constexpr double kRelativeErrorBound = 0.05;
static constexpr double kMaxSpan = 0.01;
std::mutex _table_mutex;
};
class Metric {
public:
virtual ~Metric() {}
Metric() { fprintf(stdout, "init fleet Metric\n"); }
class MetricMsg {
public:
MetricMsg() {}
MetricMsg(const std::string& label_varname, const std::string& pred_varname,
int metric_phase, int bucket_size = 1000000)
: label_varname_(label_varname),
pred_varname_(pred_varname),
metric_phase_(metric_phase) {
calculator = new BasicAucCalculator();
calculator->init(bucket_size);
}
virtual ~MetricMsg() {}
int MetricPhase() const { return metric_phase_; }
BasicAucCalculator* GetCalculator() { return calculator; }
// add_data
virtual void add_data(const Scope* exe_scope,
const paddle::platform::Place& place) {
int label_len = 0;
const int64_t* label_data = NULL;
int pred_len = 0;
const float* pred_data = NULL;
get_data<int64_t>(exe_scope, label_varname_, &label_data, &label_len);
get_data<float>(exe_scope, pred_varname_, &pred_data, &pred_len);
PADDLE_ENFORCE_EQ(label_len, pred_len,
platform::errors::PreconditionNotMet(
"the predict data length should be consistent with "
"the label data length"));
calculator->add_data(pred_data, label_data, label_len, place);
}
// get_data
template <class T = float>
static void get_data(const Scope* exe_scope, const std::string& varname,
const T** data, int* len) {
auto* var = exe_scope->FindVar(varname.c_str());
PADDLE_ENFORCE_NOT_NULL(
var, platform::errors::NotFound(
"Error: var %s is not found in scope.", varname.c_str()));
auto& cpu_tensor = var->Get<LoDTensor>();
*data = cpu_tensor.data<T>();
*len = cpu_tensor.numel();
}
template <class T = float>
static void get_data(const Scope* exe_scope, const std::string& varname,
std::vector<T>* data) {
auto* var = exe_scope->FindVar(varname.c_str());
PADDLE_ENFORCE_NOT_NULL(
var, platform::errors::NotFound(
"Error: var %s is not found in scope.", varname.c_str()));
auto& cpu_tensor = var->Get<LoDTensor>();
auto* cpu_data = cpu_tensor.data<T>();
auto len = cpu_tensor.numel();
data->resize(len);
memcpy(data->data(), cpu_data, sizeof(T) * len);
}
// parse_cmatch_rank
static inline std::pair<int, int> parse_cmatch_rank(uint64_t x) {
// only consider ignore_rank=True
return std::make_pair(static_cast<int>(x), 0);
// first 32 bit store cmatch and second 32 bit store rank
// return std::make_pair(static_cast<int>(x >> 32),
// static_cast<int>(x & 0xff));
}
protected:
std::string label_varname_;
std::string pred_varname_;
int metric_phase_;
BasicAucCalculator* calculator;
};
class WuAucMetricMsg : public MetricMsg {
public:
WuAucMetricMsg(const std::string& label_varname,
const std::string& pred_varname,
const std::string& uid_varname, int metric_phase,
int bucket_size = 1000000) {
label_varname_ = label_varname;
pred_varname_ = pred_varname;
uid_varname_ = uid_varname;
metric_phase_ = metric_phase;
calculator = new BasicAucCalculator();
}
virtual ~WuAucMetricMsg() {}
void add_data(const Scope* exe_scope,
const paddle::platform::Place& place) override {
int label_len = 0;
const int64_t* label_data = NULL;
get_data<int64_t>(exe_scope, label_varname_, &label_data, &label_len);
int pred_len = 0;
const float* pred_data = NULL;
get_data<float>(exe_scope, pred_varname_, &pred_data, &pred_len);
int uid_len = 0;
const int64_t* uid_data = NULL;
get_data<int64_t>(exe_scope, uid_varname_, &uid_data, &uid_len);
PADDLE_ENFORCE_EQ(label_len, uid_len,
platform::errors::PreconditionNotMet(
"the predict data length should be consistent with "
"the label data length"));
auto cal = GetCalculator();
cal->add_uid_data(pred_data, label_data, uid_data, label_len, place);
}
protected:
std::string uid_varname_;
};
class MultiTaskMetricMsg : public MetricMsg {
public:
MultiTaskMetricMsg(const std::string& label_varname,
const std::string& pred_varname_list, int metric_phase,
const std::string& cmatch_rank_group,
const std::string& cmatch_rank_varname,
int bucket_size = 1000000) {
label_varname_ = label_varname;
cmatch_rank_varname_ = cmatch_rank_varname;
metric_phase_ = metric_phase;
calculator = new BasicAucCalculator();
calculator->init(bucket_size);
for (auto& cmatch_rank : string::split_string(cmatch_rank_group)) {
const std::vector<std::string>& cur_cmatch_rank =
string::split_string(cmatch_rank, "_");
PADDLE_ENFORCE_EQ(
cur_cmatch_rank.size(), 2,
platform::errors::PreconditionNotMet(
"illegal multitask auc spec: %s", cmatch_rank.c_str()));
cmatch_rank_v.emplace_back(atoi(cur_cmatch_rank[0].c_str()),
atoi(cur_cmatch_rank[1].c_str()));
}
for (const auto& pred_varname : string::split_string(pred_varname_list)) {
pred_v.emplace_back(pred_varname);
}
PADDLE_ENFORCE_EQ(cmatch_rank_v.size(), pred_v.size(),
platform::errors::PreconditionNotMet(
"cmatch_rank's size [%lu] should be equal to pred "
"list's size [%lu], but ther are not equal",
cmatch_rank_v.size(), pred_v.size()));
}
virtual ~MultiTaskMetricMsg() {}
void add_data(const Scope* exe_scope,
const paddle::platform::Place& place) override {
std::vector<int64_t> cmatch_rank_data;
get_data<int64_t>(exe_scope, cmatch_rank_varname_, &cmatch_rank_data);
std::vector<int64_t> label_data;
get_data<int64_t>(exe_scope, label_varname_, &label_data);
size_t batch_size = cmatch_rank_data.size();
PADDLE_ENFORCE_EQ(
batch_size, label_data.size(),
platform::errors::PreconditionNotMet(
"illegal batch size: batch_size[%lu] and label_data[%lu]",
batch_size, label_data.size()));
std::vector<std::vector<float>> pred_data_list(pred_v.size());
for (size_t i = 0; i < pred_v.size(); ++i) {
get_data<float>(exe_scope, pred_v[i], &pred_data_list[i]);
}
for (size_t i = 0; i < pred_data_list.size(); ++i) {
PADDLE_ENFORCE_EQ(
batch_size, pred_data_list[i].size(),
platform::errors::PreconditionNotMet(
"illegal batch size: batch_size[%lu] and pred_data[%lu]",
batch_size, pred_data_list[i].size()));
}
auto cal = GetCalculator();
std::lock_guard<std::mutex> lock(cal->table_mutex());
for (size_t i = 0; i < batch_size; ++i) {
auto cmatch_rank_it =
std::find(cmatch_rank_v.begin(), cmatch_rank_v.end(),
parse_cmatch_rank(cmatch_rank_data[i]));
if (cmatch_rank_it != cmatch_rank_v.end()) {
cal->add_unlock_data(pred_data_list[std::distance(
cmatch_rank_v.begin(), cmatch_rank_it)][i],
label_data[i]);
}
}
}
protected:
std::vector<std::pair<int, int>> cmatch_rank_v;
std::vector<std::string> pred_v;
std::string cmatch_rank_varname_;
};
class CmatchRankMetricMsg : public MetricMsg {
public:
CmatchRankMetricMsg(const std::string& label_varname,
const std::string& pred_varname, int metric_phase,
const std::string& cmatch_rank_group,
const std::string& cmatch_rank_varname,
bool ignore_rank = false, int bucket_size = 1000000) {
label_varname_ = label_varname;
pred_varname_ = pred_varname;
cmatch_rank_varname_ = cmatch_rank_varname;
metric_phase_ = metric_phase;
ignore_rank_ = ignore_rank;
calculator = new BasicAucCalculator();
calculator->init(bucket_size);
for (auto& cmatch_rank : string::split_string(cmatch_rank_group)) {
if (ignore_rank) { // CmatchAUC
cmatch_rank_v.emplace_back(atoi(cmatch_rank.c_str()), 0);
continue;
}
const std::vector<std::string>& cur_cmatch_rank =
string::split_string(cmatch_rank, "_");
PADDLE_ENFORCE_EQ(
cur_cmatch_rank.size(), 2,
platform::errors::PreconditionNotMet(
"illegal cmatch_rank auc spec: %s", cmatch_rank.c_str()));
cmatch_rank_v.emplace_back(atoi(cur_cmatch_rank[0].c_str()),
atoi(cur_cmatch_rank[1].c_str()));
}
}
virtual ~CmatchRankMetricMsg() {}
void add_data(const Scope* exe_scope,
const paddle::platform::Place& place) override {
std::vector<int64_t> cmatch_rank_data;
get_data<int64_t>(exe_scope, cmatch_rank_varname_, &cmatch_rank_data);
std::vector<int64_t> label_data;
get_data<int64_t>(exe_scope, label_varname_, &label_data);
std::vector<float> pred_data;
get_data<float>(exe_scope, pred_varname_, &pred_data);
size_t batch_size = cmatch_rank_data.size();
PADDLE_ENFORCE_EQ(
batch_size, label_data.size(),
platform::errors::PreconditionNotMet(
"illegal batch size: cmatch_rank[%lu] and label_data[%lu]",
batch_size, label_data.size()));
PADDLE_ENFORCE_EQ(
batch_size, pred_data.size(),
platform::errors::PreconditionNotMet(
"illegal batch size: cmatch_rank[%lu] and pred_data[%lu]",
batch_size, pred_data.size()));
auto cal = GetCalculator();
std::lock_guard<std::mutex> lock(cal->table_mutex());
for (size_t i = 0; i < batch_size; ++i) {
const auto& cur_cmatch_rank = parse_cmatch_rank(cmatch_rank_data[i]);
for (size_t j = 0; j < cmatch_rank_v.size(); ++j) {
bool is_matched = false;
if (ignore_rank_) {
is_matched = cmatch_rank_v[j].first == cur_cmatch_rank.first;
} else {
is_matched = cmatch_rank_v[j] == cur_cmatch_rank;
}
if (is_matched) {
cal->add_unlock_data(pred_data[i], label_data[i]);
break;
}
}
}
}
protected:
std::vector<std::pair<int, int>> cmatch_rank_v;
std::string cmatch_rank_varname_;
bool ignore_rank_;
};
class MaskMetricMsg : public MetricMsg {
public:
MaskMetricMsg(const std::string& label_varname,
const std::string& pred_varname, int metric_phase,
const std::string& mask_varname, int bucket_size = 1000000) {
label_varname_ = label_varname;
pred_varname_ = pred_varname;
mask_varname_ = mask_varname;
metric_phase_ = metric_phase;
calculator = new BasicAucCalculator();
calculator->init(bucket_size);
}
virtual ~MaskMetricMsg() {}
void add_data(const Scope* exe_scope,
const paddle::platform::Place& place) override {
int label_len = 0;
const int64_t* label_data = NULL;
get_data<int64_t>(exe_scope, label_varname_, &label_data, &label_len);
int pred_len = 0;
const float* pred_data = NULL;
get_data<float>(exe_scope, pred_varname_, &pred_data, &pred_len);
int mask_len = 0;
const int64_t* mask_data = NULL;
get_data<int64_t>(exe_scope, mask_varname_, &mask_data, &mask_len);
PADDLE_ENFORCE_EQ(label_len, mask_len,
platform::errors::PreconditionNotMet(
"the predict data length should be consistent with "
"the label data length"));
auto cal = GetCalculator();
cal->add_mask_data(pred_data, label_data, mask_data, label_len, place);
}
protected:
std::string mask_varname_;
};
class CmatchRankMaskMetricMsg : public MetricMsg {
public:
CmatchRankMaskMetricMsg(const std::string& label_varname,
const std::string& pred_varname, int metric_phase,
const std::string& cmatch_rank_group,
const std::string& cmatch_rank_varname,
bool ignore_rank = false,
const std::string& mask_varname = "",
int bucket_size = 1000000) {
label_varname_ = label_varname;
pred_varname_ = pred_varname;
cmatch_rank_varname_ = cmatch_rank_varname;
metric_phase_ = metric_phase;
ignore_rank_ = ignore_rank;
mask_varname_ = mask_varname;
calculator = new BasicAucCalculator();
calculator->init(bucket_size);
for (auto& cmatch_rank : string::split_string(cmatch_rank_group)) {
if (ignore_rank) { // CmatchAUC
cmatch_rank_v.emplace_back(atoi(cmatch_rank.c_str()), 0);
continue;
}
const std::vector<std::string>& cur_cmatch_rank =
string::split_string(cmatch_rank, "_");
PADDLE_ENFORCE_EQ(
cur_cmatch_rank.size(), 2,
platform::errors::PreconditionNotMet(
"illegal cmatch_rank auc spec: %s", cmatch_rank.c_str()));
cmatch_rank_v.emplace_back(atoi(cur_cmatch_rank[0].c_str()),
atoi(cur_cmatch_rank[1].c_str()));
}
}
virtual ~CmatchRankMaskMetricMsg() {}
void add_data(const Scope* exe_scope,
const paddle::platform::Place& place) override {
std::vector<int64_t> cmatch_rank_data;
get_data<int64_t>(exe_scope, cmatch_rank_varname_, &cmatch_rank_data);
std::vector<int64_t> label_data;
get_data<int64_t>(exe_scope, label_varname_, &label_data);
std::vector<float> pred_data;
get_data<float>(exe_scope, pred_varname_, &pred_data);
size_t batch_size = cmatch_rank_data.size();
PADDLE_ENFORCE_EQ(
batch_size, label_data.size(),
platform::errors::PreconditionNotMet(
"illegal batch size: cmatch_rank[%lu] and label_data[%lu]",
batch_size, label_data.size()));
PADDLE_ENFORCE_EQ(
batch_size, pred_data.size(),
platform::errors::PreconditionNotMet(
"illegal batch size: cmatch_rank[%lu] and pred_data[%lu]",
batch_size, pred_data.size()));
std::vector<int64_t> mask_data;
if (!mask_varname_.empty()) {
get_data<int64_t>(exe_scope, mask_varname_, &mask_data);
PADDLE_ENFORCE_EQ(
batch_size, mask_data.size(),
platform::errors::PreconditionNotMet(
"illegal batch size: cmatch_rank[%lu] and mask_data[%lu]",
batch_size, mask_data.size()));
}
auto cal = GetCalculator();
std::lock_guard<std::mutex> lock(cal->table_mutex());
for (size_t i = 0; i < batch_size; ++i) {
const auto& cur_cmatch_rank = parse_cmatch_rank(cmatch_rank_data[i]);
for (size_t j = 0; j < cmatch_rank_v.size(); ++j) {
if (!mask_data.empty() && !mask_data[i]) {
continue;
}
bool is_matched = false;
if (ignore_rank_) {
is_matched = cmatch_rank_v[j].first == cur_cmatch_rank.first;
} else {
is_matched = cmatch_rank_v[j] == cur_cmatch_rank;
}
if (is_matched) {
cal->add_unlock_data(pred_data[i], label_data[i]);
break;
}
}
}
}
protected:
std::vector<std::pair<int, int>> cmatch_rank_v;
std::string cmatch_rank_varname_;
bool ignore_rank_;
std::string mask_varname_;
};
static std::shared_ptr<Metric> GetInstance() {
// PADDLE_ENFORCE_EQ(
// s_instance_ == nullptr, false,
// platform::errors::PreconditionNotMet(
// "GetInstance failed in Metric, you should use SetInstance
// firstly"));
return s_instance_;
}
static std::shared_ptr<Metric> SetInstance() {
static std::mutex mutex;
std::lock_guard<std::mutex> lock(mutex);
if (nullptr == s_instance_) {
VLOG(3) << "s_instance_ is null";
s_instance_.reset(new paddle::framework::Metric());
} else {
LOG(WARNING) << "You have already used SetInstance() before";
}
return s_instance_;
}
const std::vector<std::string> GetMetricNameList(
int metric_phase = -1) const {
VLOG(0) << "Want to Get metric phase: " << metric_phase;
if (metric_phase == -1) {
return metric_name_list_;
} else {
std::vector<std::string> ret;
for (const auto& name : metric_name_list_) {
const auto iter = metric_lists_.find(name);
PADDLE_ENFORCE_NE(
iter, metric_lists_.end(),
platform::errors::InvalidArgument(
"The metric name you provided is not registered."));
if (iter->second->MetricPhase() == metric_phase) {
VLOG(3) << name << "'s phase is " << iter->second->MetricPhase()
<< ", we want";
ret.push_back(name);
} else {
VLOG(3) << name << "'s phase is " << iter->second->MetricPhase()
<< ", not we want";
}
}
return ret;
}
}
int Phase() const { return phase_; }
int PhaseNum() const { return phase_num_; }
void FlipPhase() { phase_ = (phase_ + 1) % phase_num_; }
std::map<std::string, MetricMsg*>& GetMetricList() { return metric_lists_; }
void InitMetric(const std::string& method, const std::string& name,
const std::string& label_varname,
const std::string& pred_varname,
const std::string& cmatch_rank_varname,
const std::string& mask_varname,
const std::string& uid_varname, int metric_phase,
const std::string& cmatch_rank_group, bool ignore_rank,
int bucket_size = 1000000) {
if (method == "AucCalculator") {
metric_lists_.emplace(name, new MetricMsg(label_varname, pred_varname,
metric_phase, bucket_size));
} else if (method == "MultiTaskAucCalculator") {
metric_lists_.emplace(
name, new MultiTaskMetricMsg(label_varname, pred_varname,
metric_phase, cmatch_rank_group,
cmatch_rank_varname, bucket_size));
} else if (method == "CmatchRankAucCalculator") {
metric_lists_.emplace(name, new CmatchRankMetricMsg(
label_varname, pred_varname, metric_phase,
cmatch_rank_group, cmatch_rank_varname,
ignore_rank, bucket_size));
} else if (method == "MaskAucCalculator") {
metric_lists_.emplace(
name, new MaskMetricMsg(label_varname, pred_varname, metric_phase,
mask_varname, bucket_size));
} else if (method == "CmatchRankMaskAucCalculator") {
metric_lists_.emplace(name, new CmatchRankMaskMetricMsg(
label_varname, pred_varname, metric_phase,
cmatch_rank_group, cmatch_rank_varname,
ignore_rank, mask_varname, bucket_size));
} else if (method == "WuAucCalculator") {
metric_lists_.emplace(
name, new WuAucMetricMsg(label_varname, pred_varname, uid_varname,
metric_phase, bucket_size));
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"PSLIB Metrics only support AucCalculator, MultiTaskAucCalculator, "
"CmatchRankAucCalculator, MaskAucCalculator, WuAucCalculator and "
"CmatchRankMaskAucCalculator"));
}
metric_name_list_.emplace_back(name);
}
const std::vector<float> GetMetricMsg(const std::string& name) {
const auto iter = metric_lists_.find(name);
PADDLE_ENFORCE_NE(iter, metric_lists_.end(),
platform::errors::InvalidArgument(
"The metric name you provided is not registered."));
std::vector<float> metric_return_values_(8, 0.0);
auto* auc_cal_ = iter->second->GetCalculator();
auc_cal_->compute();
metric_return_values_[0] = auc_cal_->auc();
metric_return_values_[1] = auc_cal_->bucket_error();
metric_return_values_[2] = auc_cal_->mae();
metric_return_values_[3] = auc_cal_->rmse();
metric_return_values_[4] = auc_cal_->actual_ctr();
metric_return_values_[5] = auc_cal_->predicted_ctr();
metric_return_values_[6] =
auc_cal_->actual_ctr() / auc_cal_->predicted_ctr();
metric_return_values_[7] = auc_cal_->size();
auc_cal_->reset();
return metric_return_values_;
}
const std::vector<float> GetWuAucMetricMsg(const std::string& name) {
const auto iter = metric_lists_.find(name);
PADDLE_ENFORCE_NE(iter, metric_lists_.end(),
platform::errors::InvalidArgument(
"The metric name you provided is not registered."));
VLOG(0) << "begin GetWuAucMetricMsg";
std::vector<float> metric_return_values_(6, 0.0);
auto* auc_cal_ = iter->second->GetCalculator();
auc_cal_->computeWuAuc();
metric_return_values_[0] = auc_cal_->user_cnt();
metric_return_values_[1] = auc_cal_->size();
metric_return_values_[2] = auc_cal_->uauc();
metric_return_values_[3] = auc_cal_->wuauc();
metric_return_values_[4] =
metric_return_values_[2] / (metric_return_values_[0] + 1e-10);
metric_return_values_[5] =
metric_return_values_[3] / (metric_return_values_[1] + 1e-10);
#if defined(PADDLE_WITH_GLOO)
auto gloo_wrapper = paddle::framework::GlooWrapper::GetInstance();
if (gloo_wrapper->Size() > 1) {
auto global_metric_return_values_ =
gloo_wrapper->AllReduce(metric_return_values_, "sum");
global_metric_return_values_[4] =
global_metric_return_values_[2] /
(global_metric_return_values_[0] + 1e-10);
global_metric_return_values_[5] =
global_metric_return_values_[3] /
(global_metric_return_values_[1] + 1e-10);
auc_cal_->reset_records();
return global_metric_return_values_;
} else {
auc_cal_->reset_records();
return metric_return_values_;
}
#else
auc_cal_->reset_records();
return metric_return_values_;
#endif
}
private:
static std::shared_ptr<Metric> s_instance_;
// Metric Related
int phase_ = 1;
int phase_num_ = 2;
std::map<std::string, MetricMsg*> metric_lists_;
std::vector<std::string> metric_name_list_;
};
} // namespace framework
} // namespace paddle
#endif
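The WuAUC path declared above (add_uid_data, computeWuAuc, computeSingelUserAuc) groups records by uid, computes an AUC per user, skips users whose records contain only one class, and reports both the unweighted mean over users (uAUC) and the instance-count-weighted mean (WuAUC). A rough Python sketch of that aggregation; for brevity it uses the pairwise AUC formula per user instead of the ROC sweep in the C++ code, and is illustrative only:

from collections import defaultdict

def user_auc(pairs):  # pairs: list of (pred, label)
    pos = [p for p, l in pairs if l == 1]
    neg = [p for p, l in pairs if l == 0]
    if not pos or not neg:
        return None  # mirrors auc_ == -1: this user is skipped
    wins = sum(1.0 if pp > pn else 0.5 if pp == pn else 0.0
               for pp in pos for pn in neg)
    return wins / (len(pos) * len(neg))

def wuauc(records):  # records: list of (uid, pred, label)
    by_uid = defaultdict(list)
    for uid, pred, label in records:
        by_uid[uid].append((pred, label))
    user_cnt, size, uauc_sum, wuauc_sum = 0, 0, 0.0, 0.0
    for pairs in by_uid.values():
        auc = user_auc(pairs)
        if auc is None:
            continue
        user_cnt += 1
        size += len(pairs)
        uauc_sum += auc
        wuauc_sum += auc * len(pairs)
    return uauc_sum / max(user_cnt, 1), wuauc_sum / max(size, 1)

print(wuauc([("u1", 0.9, 1), ("u1", 0.2, 0), ("u2", 0.6, 1), ("u2", 0.7, 0)]))  # (0.5, 0.5)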
- set(PYBIND_DEPS init pybind python proto_desc memory executor fleet_wrapper box_wrapper prune
+ set(PYBIND_DEPS init pybind python proto_desc memory executor fleet_wrapper box_wrapper metrics prune
feed_fetch_method pass generate_pass pass_builder parallel_executor profiler layer tracer engine scope_pool
analysis_predictor imperative_profiler imperative_flag save_load_util dlpack_tensor device_context
gloo_wrapper infer_io_utils heter_wrapper generator op_version_registry ps_gpu_wrapper custom_operator
@@ -63,6 +63,7 @@ set(PYBIND_SRCS
ps_gpu_wrapper_py.cc
gloo_wrapper_py.cc
box_helper_py.cc
metrics_py.cc
data_set_py.cc
imperative.cc
ir.cc
......
@@ -271,6 +271,8 @@ void BindDataset(py::module *m) {
py::call_guard<py::gil_scoped_release>())
.def("set_merge_by_sid", &framework::Dataset::SetMergeBySid,
py::call_guard<py::gil_scoped_release>())
.def("set_shuffle_by_uid", &framework::Dataset::SetShuffleByUid,
py::call_guard<py::gil_scoped_release>())
.def("preprocess_instance", &framework::Dataset::PreprocessInstance, .def("preprocess_instance", &framework::Dataset::PreprocessInstance,
py::call_guard<py::gil_scoped_release>()) py::call_guard<py::gil_scoped_release>())
.def("postprocess_instance", &framework::Dataset::PostprocessInstance, .def("postprocess_instance", &framework::Dataset::PostprocessInstance,
......
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <fcntl.h>
#ifdef _POSIX_C_SOURCE
#undef _POSIX_C_SOURCE
#endif
#ifdef _XOPEN_SOURCE
#undef _XOPEN_SOURCE
#endif
#include <memory>
#include <string>
#include <vector>
#include "google/protobuf/text_format.h"
#include "paddle/fluid/framework/fleet/metrics.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/inference/io.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/platform/variant.h"
#include "paddle/fluid/pybind/metrics_py.h"
namespace py = pybind11;
#if defined(PADDLE_WITH_PSLIB)
namespace paddle {
namespace pybind {
void BindMetrics(py::module* m) {
py::class_<framework::Metric, std::shared_ptr<framework::Metric>>(*m,
"Metric")
.def(py::init([]() { return framework::Metric::SetInstance(); }))
.def("init_metric", &framework::Metric::InitMetric,
py::call_guard<py::gil_scoped_release>())
.def("flip_phase", &framework::Metric::FlipPhase,
py::call_guard<py::gil_scoped_release>())
.def("get_metric_msg", &framework::Metric::GetMetricMsg,
py::call_guard<py::gil_scoped_release>())
.def("get_wuauc_metric_msg", &framework::Metric::GetWuAucMetricMsg,
py::call_guard<py::gil_scoped_release>())
.def("get_metric_name_list", &framework::Metric::GetMetricNameList,
py::call_guard<py::gil_scoped_release>());
} // end Metrics
} // end namespace pybind
} // end namespace paddle
#endif
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "pybind11/pybind11.h"
#include "pybind11/stl.h"
namespace py = pybind11;
#if defined(PADDLE_WITH_PSLIB)
namespace paddle {
namespace pybind {
void BindMetrics(py::module* m);
} // namespace pybind
} // namespace paddle
#endif
@@ -100,6 +100,7 @@ limitations under the License. */
#include "paddle/fluid/pybind/imperative.h"
#include "paddle/fluid/pybind/inference_api.h"
#include "paddle/fluid/pybind/ir.h"
#include "paddle/fluid/pybind/metrics_py.h"
#include "paddle/fluid/pybind/ps_gpu_wrapper_py.h" #include "paddle/fluid/pybind/ps_gpu_wrapper_py.h"
#include "paddle/fluid/pybind/pybind_boost_headers.h" #include "paddle/fluid/pybind/pybind_boost_headers.h"
@@ -3678,6 +3679,7 @@ All parameter, weight, gradient are variables in Paddle.
#if defined(PADDLE_WITH_PSLIB) && !defined(PADDLE_WITH_HETERPS)
BindHeterWrapper(&m);
BindMetrics(&m);
#endif
#ifdef PADDLE_WITH_HETERPS
BindPSGPUWrapper(&m);
......
@@ -141,6 +141,23 @@ class DatasetBase(object):
def _set_input_type(self, input_type):
self.proto_desc.input_type = input_type
def _set_uid_slot(self, uid_slot):
"""
Set user slot name.
Examples:
.. code-block:: python
import paddle
dataset = paddle.distributed.fleet.DatasetBase()
dataset._set_uid_slot('6048')
Args:
set_uid_slot(string): user slot name
"""
multi_slot = self.proto_desc.multi_slot_desc
multi_slot.uid_slot = uid_slot
def _set_use_var(self, var_list):
"""
Set Variables which you will use.
...@@ -738,6 +755,23 @@ class InMemoryDataset(DatasetBase): ...@@ -738,6 +755,23 @@ class InMemoryDataset(DatasetBase):
self.merge_by_lineid = True self.merge_by_lineid = True
self.parse_ins_id = True self.parse_ins_id = True
def _set_shuffle_by_uid(self, enable_shuffle_uid):
"""
Set if Dataset need to shuffle by uid.
Args:
enable_shuffle_uid(bool): whether to shuffle instances by uid
Examples:
.. code-block:: python
import paddle
paddle.enable_static()
dataset = paddle.distributed.InMemoryDataset()
dataset._set_shuffle_by_uid(True)
"""
self.dataset.set_shuffle_by_uid(enable_shuffle_uid)
def _set_generate_unique_feasigns(self, generate_uni_feasigns, shard_num):
self.dataset.set_generate_unique_feasigns(generate_uni_feasigns)
self.gen_uni_feasigns = generate_uni_feasigns
...
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .metrics import init_metric # noqa: F401
from .metrics import print_auc # noqa: F401
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import yaml
import paddle.fluid as fluid
import logging
from paddle.distributed.utils import get_logger
__all__ = []
logger = get_logger(logging.INFO, name="metrics")
# read metric config from yaml and init MetricMsg in fleet_wrapper
def init_metric(metric_ptr,
metric_yaml_path,
cmatch_rank_var="",
mask_var="",
uid_var="",
phase=-1,
cmatch_rank_group="",
ignore_rank=False,
bucket_size=1000000):
yaml_fobj = open(metric_yaml_path)
if sys.version.startswith('2.7.13'):
content = yaml.load(yaml_fobj)
else:
content = yaml.load(yaml_fobj, Loader=yaml.FullLoader)
print("yaml metric config: \n")
print(content)
metric_runner_list = content['monitors']
if not metric_runner_list:
metric_runner_list = []
for metric_runner in metric_runner_list:
is_join = metric_runner['phase'] == 'JOINING'
phase = 1 if is_join else 0
if metric_runner['method'] == 'AucCalculator':
metric_ptr.init_metric(
metric_runner['method'], metric_runner['name'],
metric_runner['label'], metric_runner['target'],
cmatch_rank_var, mask_var, uid_var, phase, cmatch_rank_group,
ignore_rank, bucket_size)
elif metric_runner['method'] == 'MultiTaskAucCalculator':
metric_ptr.init_metric(
metric_runner['method'], metric_runner['name'],
metric_runner['label'], metric_runner['target'],
metric_runner['cmatch_var'], mask_var, uid_var, phase,
metric_runner['cmatch_group'], ignore_rank, bucket_size)
elif metric_runner['method'] == 'CmatchRankAucCalculator':
metric_ptr.init_metric(
metric_runner['method'], metric_runner['name'],
metric_runner['label'], metric_runner['target'],
metric_runner['cmatch_var'], mask_var, uid_var, phase,
metric_runner['cmatch_group'], metric_runner['ignore_rank'],
bucket_size)
elif metric_runner['method'] == 'MaskAucCalculator':
metric_ptr.init_metric(
metric_runner['method'], metric_runner['name'],
metric_runner['label'], metric_runner['target'],
cmatch_rank_var, metric_runner['mask'], uid_var, phase,
cmatch_rank_group, ignore_rank, bucket_size)
elif metric_runner['method'] == 'CmatchRankMaskAucCalculator':
metric_ptr.init_metric(
metric_runner['method'], metric_runner['name'],
metric_runner['label'], metric_runner['target'],
metric_runner['cmatch_var'], metric_runner['mask'], uid_var,
phase, metric_runner['cmatch_group'],
metric_runner['ignore_rank'], bucket_size)
elif metric_runner['method'] == 'WuAucCalculator':
metric_ptr.init_metric(
metric_runner['method'], metric_runner['name'],
metric_runner['label'], metric_runner['target'],
cmatch_rank_var, mask_var, metric_runner['uid'], phase,
cmatch_rank_group, ignore_rank, bucket_size)
else:
metric_ptr.init_metric(
metric_runner['method'], metric_runner['name'],
metric_runner['label'], metric_runner['target'],
cmatch_rank_var, mask_var, uid_var, phase, cmatch_rank_group,
ignore_rank, bucket_size)
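For illustration only, a hedged sketch of the YAML layout that init_metric reads; the 'monitors' key and per-runner fields come from the branches above, but every concrete name and value here is made up, not part of this commit.

# Hedged example: keys mirror the code above; values are illustrative.
example_yaml = """
monitors:
  - name: join_pass_auc
    method: AucCalculator        # or MultiTaskAucCalculator, WuAucCalculator, ...
    phase: JOINING               # 'JOINING' maps to phase 1, anything else to 0
    label: label_var_name
    target: prediction_var_name
"""
with open("metric.yaml", "w") as f:
    f.write(example_yaml)
init_metric(metric_ptr, "metric.yaml")   # metric_ptr: the fluid.core.Metric() singleton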
def print_metric(metric_ptr, name):
"""
Fetch the metric values of one metric from the backend and return a formatted message.
"""
if name.find("wuauc") != -1:
metric = metric_ptr.get_wuauc_metric_msg(name)
monitor_msg = "%s: User Count=%.0f INS Count=%.0f UAUC=%.6f WUAUC=%.6f "\
% (name, metric[0], metric[1], metric[4], metric[5])
else:
metric = metric_ptr.get_metric_msg(name)
monitor_msg = "%s: AUC=%.6f BUCKET_ERROR=%.6f MAE=%.6f RMSE=%.6f "\
"Actual CTR=%.6f Predicted CTR=%.6f COPC=%.6f INS Count=%.0f"\
% (name, metric[0], metric[1], metric[2], metric[3], metric[4],
metric[5], metric[6], metric[7])
# logger.info(monitor_msg)
return monitor_msg
def print_auc(metric_ptr, is_day, phase="all"):
"""
Print metrics according to stage ("day" or "pass") and phase; phase "all" collects every matching metric.
"""
if is_day is True:
stage = "day"
stage_num = -1
else:
stage = "pass"
stage_num = 1 if phase == "join" else 0
metric_results = []
name_list = metric_ptr.get_metric_name_list(stage_num)
if phase == "all":
for name in name_list:
if name.find(stage) != -1:
metric_results.append(print_metric(metric_ptr, name=name))
else:
for name in name_list:
if name.find(stage) != -1 and name.find(phase) != -1:
metric_results.append(print_metric(metric_ptr, name=name))
return metric_results
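A hedged usage sketch of the two helpers above; metric_ptr is assumed to be the fluid.core.Metric() singleton, and the stage/phase strings follow the conventions visible in print_auc.

# Hedged sketch: pass-level metrics for the join phase, then day-level metrics.
for msg in print_auc(metric_ptr, is_day=False, phase="join"):
    logger.info(msg)
for msg in print_auc(metric_ptr, is_day=True):
    logger.info(msg)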
...@@ -598,6 +598,7 @@ class GeneralRoleMaker(RoleMakerBase):
self._hdfs_path = kwargs.get("path", "").rstrip("/")
self._init_timeout_seconds = kwargs.get("init_timeout_seconds", 3600)
self._run_timeout_seconds = kwargs.get("run_timeout_seconds", 9999999)
self._use_metric = kwargs.get("use_metric", False)
ip_port = kwargs.get("http_ip_port", "")
self._use_ps_gpu = kwargs.get("use_ps_gpu", False)
self._http_ip_port = []
...@@ -668,7 +669,7 @@ class GeneralRoleMaker(RoleMakerBase):
self._hdfs_name, self._hdfs_ugi)
gloo.init()
self._node_type_comm = gloo
if self._use_ps_gpu:
if self._use_ps_gpu or self._use_metric:
Gloo_strategy = fluid.core.GlooParallelStrategy()
Gloo_strategy.rank = current_id
Gloo_strategy.rank_num = len(worker_endpoints)
...
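The new use_metric switch is only read from the GeneralRoleMaker kwargs; a hedged sketch of turning it on follows. The import path and the other values are illustrative assumptions, not part of this diff.

# Hedged sketch: kwargs other than use_metric mirror those read in the diff above;
# the import path is assumed and may differ by Paddle version.
from paddle.fluid.incubate.fleet.base.role_maker import GeneralRoleMaker

role_maker = GeneralRoleMaker(
    path="/tmp/gloo_init",            # illustrative path used for gloo init
    http_ip_port="127.0.0.1:30004",   # illustrative
    use_metric=True)                  # enables the Gloo parallel strategy branch above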
...@@ -70,6 +70,14 @@ class TestDataset(unittest.TestCase):
self.assertTrue(dataset.parse_content)
self.assertEqual(dataset.trainer_num, 1)
def test_shuffle_by_uid(self):
"""
Testcase for shuffle_by_uid.
"""
dataset = paddle.distributed.InMemoryDataset()
dataset._set_uid_slot('6048')
dataset._set_shuffle_by_uid(True)
def test_run_with_dump(self):
"""
Testcase for InMemoryDataset from create to run.
...
...@@ -269,6 +269,7 @@ packages=['paddle',
'paddle.dataset',
'paddle.reader',
'paddle.distributed',
'paddle.distributed.metric',
'paddle.incubate',
'paddle.incubate.optimizer',
'paddle.incubate.checkpoint',
...