未验证 提交 2e6be886 编写于 作者: F Fan Zhang 提交者: GitHub

[PSLIB] Add Metrics Module, Support User-defined Add Metric (#38789)

* [PSLIB] Add Metrics Module, Support User-defined Add Metric

* [PSLIB] Modify According to CI

* [PSLIB] Modify According to CI

* [PSLIB] Modify According to CI

* [PSLIB] Modify According to CI Coverage

* [PSLIB] Modify According to CI

* [PSLIB] Modify According to CI

* [PSLIB] Modify According to CI

* [PSLIB] Modify According to CI

* [PSLIB] Modify According to CI

* [PSLIB] Modify According to CI Coverage

* [PSLIB] Modify According to CI Coverage

* [PSLIB] Modify According to CI Coverage

* modify role_maker

* update CMakeLists.txt
上级 3ef2922b
...@@ -293,7 +293,7 @@ if(WITH_DISTRIBUTE) ...@@ -293,7 +293,7 @@ if(WITH_DISTRIBUTE)
ps_gpu_trainer.cc downpour_worker.cc downpour_worker_opt.cc ps_gpu_trainer.cc downpour_worker.cc downpour_worker_opt.cc
pull_dense_worker.cc section_worker.cc device_worker_factory.cc data_set.cc DEPS op_registry pull_dense_worker.cc section_worker.cc device_worker_factory.cc data_set.cc DEPS op_registry
device_context scope framework_proto trainer_desc_proto glog fs shell device_context scope framework_proto trainer_desc_proto glog fs shell
fleet_wrapper heter_wrapper ps_gpu_wrapper box_wrapper lodtensor_printer fleet_wrapper heter_wrapper ps_gpu_wrapper box_wrapper metrics lodtensor_printer
lod_rank_table feed_fetch_method collective_helper ${GLOB_DISTRIBUTE_DEPS} lod_rank_table feed_fetch_method collective_helper ${GLOB_DISTRIBUTE_DEPS}
graph_to_program_pass variable_helper data_feed_proto timer monitor graph_to_program_pass variable_helper data_feed_proto timer monitor
heter_service_proto fleet_executor ${BRPC_DEP}) heter_service_proto fleet_executor ${BRPC_DEP})
...@@ -315,7 +315,7 @@ if(WITH_DISTRIBUTE) ...@@ -315,7 +315,7 @@ if(WITH_DISTRIBUTE)
pull_dense_worker.cc section_worker.cc heter_section_worker.cc device_worker_factory.cc data_set.cc DEPS op_registry pull_dense_worker.cc section_worker.cc heter_section_worker.cc device_worker_factory.cc data_set.cc DEPS op_registry
device_context scope framework_proto data_feed_proto heter_service_proto trainer_desc_proto glog device_context scope framework_proto data_feed_proto heter_service_proto trainer_desc_proto glog
index_sampler index_wrapper sampler index_dataset_proto index_sampler index_wrapper sampler index_dataset_proto
lod_rank_table fs shell fleet_wrapper heter_wrapper box_wrapper lodtensor_printer feed_fetch_method lod_rank_table fs shell fleet_wrapper heter_wrapper box_wrapper metrics lodtensor_printer feed_fetch_method
graph_to_program_pass variable_helper timer monitor heter_service_proto fleet heter_server brpc fleet_executor) graph_to_program_pass variable_helper timer monitor heter_service_proto fleet heter_server brpc fleet_executor)
set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor") set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor")
if (CMAKE_CXX_COMPILER_VERSION VERSION_GREATER 7.0) if (CMAKE_CXX_COMPILER_VERSION VERSION_GREATER 7.0)
...@@ -336,7 +336,7 @@ if(WITH_DISTRIBUTE) ...@@ -336,7 +336,7 @@ if(WITH_DISTRIBUTE)
ps_gpu_trainer.cc downpour_worker.cc downpour_worker_opt.cc ps_gpu_trainer.cc downpour_worker.cc downpour_worker_opt.cc
pull_dense_worker.cc section_worker.cc device_worker_factory.cc data_set.cc DEPS op_registry pull_dense_worker.cc section_worker.cc device_worker_factory.cc data_set.cc DEPS op_registry
device_context scope framework_proto data_feed_proto heter_service_proto trainer_desc_proto glog device_context scope framework_proto data_feed_proto heter_service_proto trainer_desc_proto glog
lod_rank_table fs shell fleet_wrapper heter_wrapper ps_gpu_wrapper box_wrapper lodtensor_printer feed_fetch_method lod_rank_table fs shell fleet_wrapper heter_wrapper ps_gpu_wrapper box_wrapper metrics lodtensor_printer feed_fetch_method
graph_to_program_pass variable_helper timer monitor fleet_executor) graph_to_program_pass variable_helper timer monitor fleet_executor)
endif() endif()
elseif(WITH_PSLIB) elseif(WITH_PSLIB)
......
...@@ -340,6 +340,7 @@ InMemoryDataFeed<T>::InMemoryDataFeed() { ...@@ -340,6 +340,7 @@ InMemoryDataFeed<T>::InMemoryDataFeed() {
this->thread_id_ = 0; this->thread_id_ = 0;
this->thread_num_ = 1; this->thread_num_ = 1;
this->parse_ins_id_ = false; this->parse_ins_id_ = false;
this->parse_uid_ = false;
this->parse_content_ = false; this->parse_content_ = false;
this->parse_logkey_ = false; this->parse_logkey_ = false;
this->enable_pv_merge_ = false; this->enable_pv_merge_ = false;
...@@ -498,6 +499,11 @@ void InMemoryDataFeed<T>::SetParseInsId(bool parse_ins_id) { ...@@ -498,6 +499,11 @@ void InMemoryDataFeed<T>::SetParseInsId(bool parse_ins_id) {
parse_ins_id_ = parse_ins_id; parse_ins_id_ = parse_ins_id;
} }
template <typename T>
void InMemoryDataFeed<T>::SetParseUid(bool parse_uid) {
parse_uid_ = parse_uid;
}
template <typename T> template <typename T>
void InMemoryDataFeed<T>::LoadIntoMemory() { void InMemoryDataFeed<T>::LoadIntoMemory() {
#ifdef _LINUX #ifdef _LINUX
...@@ -1047,6 +1053,7 @@ void MultiSlotInMemoryDataFeed::Init( ...@@ -1047,6 +1053,7 @@ void MultiSlotInMemoryDataFeed::Init(
use_slots_shape_.push_back(local_shape); use_slots_shape_.push_back(local_shape);
} }
} }
uid_slot_ = multi_slot_desc.uid_slot();
feed_vec_.resize(use_slots_.size()); feed_vec_.resize(use_slots_.size());
const int kEstimatedFeasignNumPerSlot = 5; // Magic Number const int kEstimatedFeasignNumPerSlot = 5; // Magic Number
for (size_t i = 0; i < all_slot_num; i++) { for (size_t i = 0; i < all_slot_num; i++) {
...@@ -1160,6 +1167,19 @@ bool MultiSlotInMemoryDataFeed::ParseOneInstanceFromPipe(Record* instance) { ...@@ -1160,6 +1167,19 @@ bool MultiSlotInMemoryDataFeed::ParseOneInstanceFromPipe(Record* instance) {
"\nWe detect the feasign number of this slot is %d, " "\nWe detect the feasign number of this slot is %d, "
"which is illegal.", "which is illegal.",
str, i, num)); str, i, num));
#ifdef PADDLE_WITH_PSLIB
if (parse_uid_ && all_slots_[i] == uid_slot_) {
PADDLE_ENFORCE(num == 1 && all_slots_type_[i][0] == 'u',
platform::errors::PreconditionNotMet(
"The uid has to be uint64 and single.\n"
"please check this error line: %s",
str));
char* uidptr = endptr;
uint64_t feasign = (uint64_t)strtoull(uidptr, &uidptr, 10);
instance->uid_ = feasign;
}
#endif
if (idx != -1) { if (idx != -1) {
if (all_slots_type_[i][0] == 'f') { // float if (all_slots_type_[i][0] == 'f') { // float
for (int j = 0; j < num; ++j) { for (int j = 0; j < num; ++j) {
......
...@@ -191,6 +191,7 @@ struct Record { ...@@ -191,6 +191,7 @@ struct Record {
uint64_t search_id; uint64_t search_id;
uint32_t rank; uint32_t rank;
uint32_t cmatch; uint32_t cmatch;
std::string uid_;
}; };
inline SlotRecord make_slotrecord() { inline SlotRecord make_slotrecord() {
...@@ -562,6 +563,7 @@ class DataFeed { ...@@ -562,6 +563,7 @@ class DataFeed {
virtual void SetThreadNum(int thread_num) {} virtual void SetThreadNum(int thread_num) {}
// This function will do nothing at default // This function will do nothing at default
virtual void SetParseInsId(bool parse_ins_id) {} virtual void SetParseInsId(bool parse_ins_id) {}
virtual void SetParseUid(bool parse_uid) {}
virtual void SetParseContent(bool parse_content) {} virtual void SetParseContent(bool parse_content) {}
virtual void SetParseLogKey(bool parse_logkey) {} virtual void SetParseLogKey(bool parse_logkey) {}
virtual void SetEnablePvMerge(bool enable_pv_merge) {} virtual void SetEnablePvMerge(bool enable_pv_merge) {}
...@@ -645,6 +647,7 @@ class DataFeed { ...@@ -645,6 +647,7 @@ class DataFeed {
std::vector<std::string> ins_id_vec_; std::vector<std::string> ins_id_vec_;
std::vector<std::string> ins_content_vec_; std::vector<std::string> ins_content_vec_;
platform::Place place_; platform::Place place_;
std::string uid_slot_;
// The input type of pipe reader, 0 for one sample, 1 for one batch // The input type of pipe reader, 0 for one sample, 1 for one batch
int input_type_; int input_type_;
...@@ -709,6 +712,7 @@ class InMemoryDataFeed : public DataFeed { ...@@ -709,6 +712,7 @@ class InMemoryDataFeed : public DataFeed {
virtual void SetThreadId(int thread_id); virtual void SetThreadId(int thread_id);
virtual void SetThreadNum(int thread_num); virtual void SetThreadNum(int thread_num);
virtual void SetParseInsId(bool parse_ins_id); virtual void SetParseInsId(bool parse_ins_id);
virtual void SetParseUid(bool parse_uid);
virtual void SetParseContent(bool parse_content); virtual void SetParseContent(bool parse_content);
virtual void SetParseLogKey(bool parse_logkey); virtual void SetParseLogKey(bool parse_logkey);
virtual void SetEnablePvMerge(bool enable_pv_merge); virtual void SetEnablePvMerge(bool enable_pv_merge);
...@@ -737,6 +741,7 @@ class InMemoryDataFeed : public DataFeed { ...@@ -737,6 +741,7 @@ class InMemoryDataFeed : public DataFeed {
int thread_id_; int thread_id_;
int thread_num_; int thread_num_;
bool parse_ins_id_; bool parse_ins_id_;
bool parse_uid_;
bool parse_content_; bool parse_content_;
bool parse_logkey_; bool parse_logkey_;
bool enable_pv_merge_; bool enable_pv_merge_;
......
...@@ -22,7 +22,10 @@ message Slot { ...@@ -22,7 +22,10 @@ message Slot {
repeated int32 shape = 5; // we can define N-D Tensor repeated int32 shape = 5; // we can define N-D Tensor
} }
message MultiSlotDesc { repeated Slot slots = 1; } message MultiSlotDesc {
repeated Slot slots = 1;
optional string uid_slot = 2;
}
message DataFeedDesc { message DataFeedDesc {
optional string name = 1; optional string name = 1;
......
...@@ -57,6 +57,8 @@ DatasetImpl<T>::DatasetImpl() { ...@@ -57,6 +57,8 @@ DatasetImpl<T>::DatasetImpl() {
parse_logkey_ = false; parse_logkey_ = false;
preload_thread_num_ = 0; preload_thread_num_ = 0;
global_index_ = 0; global_index_ = 0;
shuffle_by_uid_ = false;
parse_uid_ = false;
} }
// set filelist, file_idx_ will reset to zero. // set filelist, file_idx_ will reset to zero.
...@@ -150,6 +152,12 @@ void DatasetImpl<T>::SetMergeBySid(bool is_merge) { ...@@ -150,6 +152,12 @@ void DatasetImpl<T>::SetMergeBySid(bool is_merge) {
merge_by_sid_ = is_merge; merge_by_sid_ = is_merge;
} }
template <typename T>
void DatasetImpl<T>::SetShuffleByUid(bool enable_shuffle_uid) {
shuffle_by_uid_ = enable_shuffle_uid;
parse_uid_ = true;
}
template <typename T> template <typename T>
void DatasetImpl<T>::SetEnablePvMerge(bool enable_pv_merge) { void DatasetImpl<T>::SetEnablePvMerge(bool enable_pv_merge) {
enable_pv_merge_ = enable_pv_merge; enable_pv_merge_ = enable_pv_merge;
...@@ -664,11 +672,14 @@ void MultiSlotDataset::GlobalShuffle(int thread_num) { ...@@ -664,11 +672,14 @@ void MultiSlotDataset::GlobalShuffle(int thread_num) {
<< input_channel_->Size(); << input_channel_->Size();
auto get_client_id = [this, fleet_ptr](const Record& data) -> size_t { auto get_client_id = [this, fleet_ptr](const Record& data) -> size_t {
if (!this->merge_by_insid_) { if (this->merge_by_insid_) {
return fleet_ptr->LocalRandomEngine()() % this->trainer_num_;
} else {
return XXH64(data.ins_id_.data(), data.ins_id_.length(), 0) % return XXH64(data.ins_id_.data(), data.ins_id_.length(), 0) %
this->trainer_num_; this->trainer_num_;
} else if (this->shuffle_by_uid_) {
return XXH64(data.uid_.data(), data.uid_.length(), 0) %
this->trainer_num_;
} else {
return fleet_ptr->LocalRandomEngine()() % this->trainer_num_;
} }
}; };
...@@ -902,6 +913,7 @@ void DatasetImpl<T>::CreateReaders() { ...@@ -902,6 +913,7 @@ void DatasetImpl<T>::CreateReaders() {
readers_[i]->SetFeaNum(&total_fea_num_); readers_[i]->SetFeaNum(&total_fea_num_);
readers_[i]->SetFileList(filelist_); readers_[i]->SetFileList(filelist_);
readers_[i]->SetParseInsId(parse_ins_id_); readers_[i]->SetParseInsId(parse_ins_id_);
readers_[i]->SetParseUid(parse_uid_);
readers_[i]->SetParseContent(parse_content_); readers_[i]->SetParseContent(parse_content_);
readers_[i]->SetParseLogKey(parse_logkey_); readers_[i]->SetParseLogKey(parse_logkey_);
readers_[i]->SetEnablePvMerge(enable_pv_merge_); readers_[i]->SetEnablePvMerge(enable_pv_merge_);
...@@ -972,6 +984,7 @@ void DatasetImpl<T>::CreatePreLoadReaders() { ...@@ -972,6 +984,7 @@ void DatasetImpl<T>::CreatePreLoadReaders() {
preload_readers_[i]->SetFeaNumMutex(&mutex_for_fea_num_); preload_readers_[i]->SetFeaNumMutex(&mutex_for_fea_num_);
preload_readers_[i]->SetFeaNum(&total_fea_num_); preload_readers_[i]->SetFeaNum(&total_fea_num_);
preload_readers_[i]->SetParseInsId(parse_ins_id_); preload_readers_[i]->SetParseInsId(parse_ins_id_);
preload_readers_[i]->SetParseUid(parse_uid_);
preload_readers_[i]->SetParseContent(parse_content_); preload_readers_[i]->SetParseContent(parse_content_);
preload_readers_[i]->SetParseLogKey(parse_logkey_); preload_readers_[i]->SetParseLogKey(parse_logkey_);
preload_readers_[i]->SetEnablePvMerge(enable_pv_merge_); preload_readers_[i]->SetEnablePvMerge(enable_pv_merge_);
......
...@@ -81,6 +81,7 @@ class Dataset { ...@@ -81,6 +81,7 @@ class Dataset {
virtual void SetEnablePvMerge(bool enable_pv_merge) = 0; virtual void SetEnablePvMerge(bool enable_pv_merge) = 0;
virtual bool EnablePvMerge() = 0; virtual bool EnablePvMerge() = 0;
virtual void SetMergeBySid(bool is_merge) = 0; virtual void SetMergeBySid(bool is_merge) = 0;
virtual void SetShuffleByUid(bool enable_shuffle_uid) = 0;
// set merge by ins id // set merge by ins id
virtual void SetMergeByInsId(int merge_size) = 0; virtual void SetMergeByInsId(int merge_size) = 0;
virtual void SetGenerateUniqueFeasign(bool gen_uni_feasigns) = 0; virtual void SetGenerateUniqueFeasign(bool gen_uni_feasigns) = 0;
...@@ -189,6 +190,7 @@ class DatasetImpl : public Dataset { ...@@ -189,6 +190,7 @@ class DatasetImpl : public Dataset {
virtual void SetParseLogKey(bool parse_logkey); virtual void SetParseLogKey(bool parse_logkey);
virtual void SetEnablePvMerge(bool enable_pv_merge); virtual void SetEnablePvMerge(bool enable_pv_merge);
virtual void SetMergeBySid(bool is_merge); virtual void SetMergeBySid(bool is_merge);
virtual void SetShuffleByUid(bool enable_shuffle_uid);
virtual void SetMergeByInsId(int merge_size); virtual void SetMergeByInsId(int merge_size);
virtual void SetGenerateUniqueFeasign(bool gen_uni_feasigns); virtual void SetGenerateUniqueFeasign(bool gen_uni_feasigns);
...@@ -307,6 +309,8 @@ class DatasetImpl : public Dataset { ...@@ -307,6 +309,8 @@ class DatasetImpl : public Dataset {
bool parse_content_; bool parse_content_;
bool parse_logkey_; bool parse_logkey_;
bool merge_by_sid_; bool merge_by_sid_;
bool shuffle_by_uid_;
bool parse_uid_;
bool enable_pv_merge_; // True means to merge pv bool enable_pv_merge_; // True means to merge pv
int current_phase_; // 1 join, 0 update int current_phase_; // 1 join, 0 update
size_t merge_size_; size_t merge_size_;
......
...@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and ...@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/framework/device_worker.h" #include "paddle/fluid/framework/device_worker.h"
#include "paddle/fluid/framework/fleet/metrics.h"
#include "paddle/fluid/platform/cpu_helper.h" #include "paddle/fluid/platform/cpu_helper.h"
namespace pten { namespace pten {
...@@ -32,7 +33,6 @@ class Variable; ...@@ -32,7 +33,6 @@ class Variable;
namespace paddle { namespace paddle {
namespace framework { namespace framework {
void DownpourWorker::Initialize(const TrainerDesc& desc) { void DownpourWorker::Initialize(const TrainerDesc& desc) {
param_ = desc.downpour_param(); param_ = desc.downpour_param();
for (int i = 0; i < param_.sparse_table_size(); ++i) { for (int i = 0; i < param_.sparse_table_size(); ++i) {
...@@ -740,6 +740,23 @@ void DownpourWorker::TrainFilesWithProfiler() { ...@@ -740,6 +740,23 @@ void DownpourWorker::TrainFilesWithProfiler() {
} }
} }
#ifdef PADDLE_WITH_PSLIB
/**
* @brief add auc monitor
*/
inline void AddAucMonitor(const Scope* scope, const platform::Place& place) {
auto metric_ptr = Metric::GetInstance();
auto& metric_list = metric_ptr->GetMetricList();
for (auto iter = metric_list.begin(); iter != metric_list.end(); iter++) {
auto* metric_msg = iter->second;
if (metric_ptr->Phase() != metric_msg->MetricPhase()) {
continue;
}
metric_msg->add_data(scope, place);
}
}
#endif
void DownpourWorker::TrainFiles() { void DownpourWorker::TrainFiles() {
VLOG(3) << "Begin to train files"; VLOG(3) << "Begin to train files";
platform::SetNumThreads(1); platform::SetNumThreads(1);
...@@ -837,6 +854,13 @@ void DownpourWorker::TrainFiles() { ...@@ -837,6 +854,13 @@ void DownpourWorker::TrainFiles() {
} }
} }
#ifdef PADDLE_WITH_PSLIB
// add data for MetricMsg
if (Metric::GetInstance() != nullptr) {
AddAucMonitor(thread_scope_, place_);
}
#endif
// check inf and nan // check inf and nan
for (std::string& var_name : check_nan_var_names_) { for (std::string& var_name : check_nan_var_names_) {
Variable* var = thread_scope_->FindVar(var_name); Variable* var = thread_scope_->FindVar(var_name);
......
...@@ -42,8 +42,10 @@ endif(WITH_BOX_PS) ...@@ -42,8 +42,10 @@ endif(WITH_BOX_PS)
if(WITH_GLOO) if(WITH_GLOO)
cc_library(gloo_wrapper SRCS gloo_wrapper.cc DEPS framework_proto variable_helper scope gloo) cc_library(gloo_wrapper SRCS gloo_wrapper.cc DEPS framework_proto variable_helper scope gloo)
cc_library(metrics SRCS metrics.cc DEPS gloo_wrapper)
else() else()
cc_library(gloo_wrapper SRCS gloo_wrapper.cc DEPS framework_proto variable_helper scope) cc_library(gloo_wrapper SRCS gloo_wrapper.cc DEPS framework_proto variable_helper scope)
cc_library(metrics SRCS metrics.cc DEPS gloo_wrapper)
endif(WITH_GLOO) endif(WITH_GLOO)
if(WITH_PSLIB) if(WITH_PSLIB)
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/fleet/metrics.h"
#include <algorithm>
#include <ctime>
#include <memory>
#include <numeric>
#include "paddle/fluid/framework/lod_tensor.h"
#if defined(PADDLE_WITH_PSLIB)
namespace paddle {
namespace framework {
std::shared_ptr<Metric> Metric::s_instance_ = nullptr;
void BasicAucCalculator::init(int table_size) {
set_table_size(table_size);
// init CPU memory
for (int i = 0; i < 2; i++) {
_table[i] = std::vector<double>();
}
// reset
reset();
}
void BasicAucCalculator::reset() {
// reset CPU counter
for (int i = 0; i < 2; i++) {
_table[i].assign(_table_size, 0.0);
}
_local_abserr = 0;
_local_sqrerr = 0;
_local_pred = 0;
}
void BasicAucCalculator::add_data(const float* d_pred, const int64_t* d_label,
int batch_size,
const paddle::platform::Place& place) {
thread_local std::vector<float> h_pred;
thread_local std::vector<int64_t> h_label;
h_pred.resize(batch_size);
h_label.resize(batch_size);
memcpy(h_pred.data(), d_pred, sizeof(float) * batch_size);
memcpy(h_label.data(), d_label, sizeof(int64_t) * batch_size);
std::lock_guard<std::mutex> lock(_table_mutex);
for (int i = 0; i < batch_size; ++i) {
add_unlock_data(h_pred[i], h_label[i]);
}
}
void BasicAucCalculator::add_unlock_data(double pred, int label) {
PADDLE_ENFORCE_GE(pred, 0.0, platform::errors::PreconditionNotMet(
"pred should be greater than 0"));
PADDLE_ENFORCE_LE(pred, 1.0, platform::errors::PreconditionNotMet(
"pred should be lower than 1"));
PADDLE_ENFORCE_EQ(
label * label, label,
platform::errors::PreconditionNotMet(
"label must be equal to 0 or 1, but its value is: %d", label));
int pos = std::min(static_cast<int>(pred * _table_size), _table_size - 1);
PADDLE_ENFORCE_GE(
pos, 0,
platform::errors::PreconditionNotMet(
"pos must be equal or greater than 0, but its value is: %d", pos));
PADDLE_ENFORCE_LT(
pos, _table_size,
platform::errors::PreconditionNotMet(
"pos must be less than table_size, but its value is: %d", pos));
_local_abserr += fabs(pred - label);
_local_sqrerr += (pred - label) * (pred - label);
_local_pred += pred;
++_table[label][pos];
}
// add mask data
void BasicAucCalculator::add_mask_data(const float* d_pred,
const int64_t* d_label,
const int64_t* d_mask, int batch_size,
const paddle::platform::Place& place) {
thread_local std::vector<float> h_pred;
thread_local std::vector<int64_t> h_label;
thread_local std::vector<int64_t> h_mask;
h_pred.resize(batch_size);
h_label.resize(batch_size);
h_mask.resize(batch_size);
memcpy(h_pred.data(), d_pred, sizeof(float) * batch_size);
memcpy(h_label.data(), d_label, sizeof(int64_t) * batch_size);
memcpy(h_mask.data(), d_mask, sizeof(int64_t) * batch_size);
std::lock_guard<std::mutex> lock(_table_mutex);
for (int i = 0; i < batch_size; ++i) {
if (h_mask[i]) {
add_unlock_data(h_pred[i], h_label[i]);
}
}
}
void BasicAucCalculator::compute() {
#if defined(PADDLE_WITH_GLOO)
double area = 0;
double fp = 0;
double tp = 0;
auto gloo_wrapper = paddle::framework::GlooWrapper::GetInstance();
if (!gloo_wrapper->IsInitialized()) {
VLOG(0) << "GLOO is not inited";
gloo_wrapper->Init();
}
if (gloo_wrapper->Size() > 1) {
auto neg_table = gloo_wrapper->AllReduce(_table[0], "sum");
auto pos_table = gloo_wrapper->AllReduce(_table[1], "sum");
for (int i = _table_size - 1; i >= 0; i--) {
double newfp = fp + neg_table[i];
double newtp = tp + pos_table[i];
area += (newfp - fp) * (tp + newtp) / 2;
fp = newfp;
tp = newtp;
}
} else {
for (int i = _table_size - 1; i >= 0; i--) {
double newfp = fp + _table[0][i];
double newtp = tp + _table[1][i];
area += (newfp - fp) * (tp + newtp) / 2;
fp = newfp;
tp = newtp;
}
}
if (fp < 1e-3 || tp < 1e-3) {
_auc = -0.5; // which means all nonclick or click
} else {
_auc = area / (fp * tp);
}
if (gloo_wrapper->Size() > 1) {
// allreduce sum
std::vector<double> local_abserr_vec(1, _local_abserr);
std::vector<double> local_sqrerr_vec(1, _local_sqrerr);
std::vector<double> local_pred_vec(1, _local_pred);
auto global_abserr_vec = gloo_wrapper->AllReduce(local_abserr_vec, "sum");
auto global_sqrerr_vec = gloo_wrapper->AllReduce(local_sqrerr_vec, "sum");
auto global_pred_vec = gloo_wrapper->AllReduce(local_pred_vec, "sum");
_mae = global_abserr_vec[0] / (fp + tp);
_rmse = sqrt(global_sqrerr_vec[0] / (fp + tp));
_predicted_ctr = global_pred_vec[0] / (fp + tp);
} else {
_mae = _local_abserr / (fp + tp);
_rmse = sqrt(_local_sqrerr / (fp + tp));
_predicted_ctr = _local_pred / (fp + tp);
}
_actual_ctr = tp / (fp + tp);
_size = fp + tp;
calculate_bucket_error();
#endif
}
void BasicAucCalculator::calculate_bucket_error() {
#if defined(PADDLE_WITH_GLOO)
double last_ctr = -1;
double impression_sum = 0;
double ctr_sum = 0.0;
double click_sum = 0.0;
double error_sum = 0.0;
double error_count = 0;
auto gloo_wrapper = paddle::framework::GlooWrapper::GetInstance();
if (gloo_wrapper->Size() > 1) {
auto neg_table = gloo_wrapper->AllReduce(_table[0], "sum");
auto pos_table = gloo_wrapper->AllReduce(_table[1], "sum");
for (int i = 0; i < _table_size; i++) {
double click = pos_table[i];
double show = neg_table[i] + pos_table[i];
double ctr = static_cast<double>(i) / _table_size;
if (fabs(ctr - last_ctr) > kMaxSpan) {
last_ctr = ctr;
impression_sum = 0.0;
ctr_sum = 0.0;
click_sum = 0.0;
}
impression_sum += show;
ctr_sum += ctr * show;
click_sum += click;
double adjust_ctr = ctr_sum / impression_sum;
double relative_error =
sqrt((1 - adjust_ctr) / (adjust_ctr * impression_sum));
if (relative_error < kRelativeErrorBound) {
double actual_ctr = click_sum / impression_sum;
double relative_ctr_error = fabs(actual_ctr / adjust_ctr - 1);
error_sum += relative_ctr_error * impression_sum;
error_count += impression_sum;
last_ctr = -1;
}
}
} else {
double* table[2] = {&_table[0][0], &_table[1][0]};
for (int i = 0; i < _table_size; i++) {
double click = table[1][i];
double show = table[0][i] + table[1][i];
double ctr = static_cast<double>(i) / _table_size;
if (fabs(ctr - last_ctr) > kMaxSpan) {
last_ctr = ctr;
impression_sum = 0.0;
ctr_sum = 0.0;
click_sum = 0.0;
}
impression_sum += show;
ctr_sum += ctr * show;
click_sum += click;
double adjust_ctr = ctr_sum / impression_sum;
double relative_error =
sqrt((1 - adjust_ctr) / (adjust_ctr * impression_sum));
if (relative_error < kRelativeErrorBound) {
double actual_ctr = click_sum / impression_sum;
double relative_ctr_error = fabs(actual_ctr / adjust_ctr - 1);
error_sum += relative_ctr_error * impression_sum;
error_count += impression_sum;
last_ctr = -1;
}
}
}
_bucket_error = error_count > 0 ? error_sum / error_count : 0.0;
#endif
}
void BasicAucCalculator::reset_records() {
// reset wuauc_records_
wuauc_records_.clear();
_user_cnt = 0;
_size = 0;
_uauc = 0;
_wuauc = 0;
}
// add uid data
void BasicAucCalculator::add_uid_data(const float* d_pred,
const int64_t* d_label,
const int64_t* d_uid, int batch_size,
const paddle::platform::Place& place) {
thread_local std::vector<float> h_pred;
thread_local std::vector<int64_t> h_label;
thread_local std::vector<uint64_t> h_uid;
h_pred.resize(batch_size);
h_label.resize(batch_size);
h_uid.resize(batch_size);
memcpy(h_pred.data(), d_pred, sizeof(float) * batch_size);
memcpy(h_label.data(), d_label, sizeof(int64_t) * batch_size);
memcpy(h_uid.data(), d_uid, sizeof(uint64_t) * batch_size);
std::lock_guard<std::mutex> lock(_table_mutex);
for (int i = 0; i < batch_size; ++i) {
add_uid_unlock_data(h_pred[i], h_label[i], static_cast<uint64_t>(h_uid[i]));
}
}
void BasicAucCalculator::add_uid_unlock_data(double pred, int label,
uint64_t uid) {
PADDLE_ENFORCE_GE(pred, 0.0, platform::errors::PreconditionNotMet(
"pred should be greater than 0"));
PADDLE_ENFORCE_LE(pred, 1.0, platform::errors::PreconditionNotMet(
"pred should be lower than 1"));
PADDLE_ENFORCE_EQ(
label * label, label,
platform::errors::PreconditionNotMet(
"label must be equal to 0 or 1, but its value is: %d", label));
WuaucRecord record;
record.uid_ = uid;
record.label_ = label;
record.pred_ = pred;
wuauc_records_.emplace_back(std::move(record));
}
void BasicAucCalculator::computeWuAuc() {
std::sort(wuauc_records_.begin(), wuauc_records_.end(),
[](const WuaucRecord& lhs, const WuaucRecord& rhs) {
if (lhs.uid_ == rhs.uid_) {
if (lhs.pred_ == rhs.pred_) {
return lhs.label_ < rhs.label_;
} else {
return lhs.pred_ > rhs.pred_;
}
} else {
return lhs.uid_ > rhs.uid_;
}
});
WuaucRocData roc_data;
uint64_t prev_uid = 0;
size_t prev_pos = 0;
for (size_t i = 0; i < wuauc_records_.size(); ++i) {
if (wuauc_records_[i].uid_ != prev_uid) {
std::vector<WuaucRecord> single_user_recs(
wuauc_records_.begin() + prev_pos, wuauc_records_.begin() + i);
roc_data = computeSingelUserAuc(single_user_recs);
if (roc_data.auc_ != -1) {
double ins_num = (roc_data.tp_ + roc_data.fp_);
_user_cnt += 1;
_size += ins_num;
_uauc += roc_data.auc_;
_wuauc += roc_data.auc_ * ins_num;
}
prev_uid = wuauc_records_[i].uid_;
prev_pos = i;
}
}
std::vector<WuaucRecord> single_user_recs(wuauc_records_.begin() + prev_pos,
wuauc_records_.end());
roc_data = computeSingelUserAuc(single_user_recs);
if (roc_data.auc_ != -1) {
double ins_num = (roc_data.tp_ + roc_data.fp_);
_user_cnt += 1;
_size += ins_num;
_uauc += roc_data.auc_;
_wuauc += roc_data.auc_ * ins_num;
}
}
BasicAucCalculator::WuaucRocData BasicAucCalculator::computeSingelUserAuc(
const std::vector<WuaucRecord>& records) {
double tp = 0.0;
double fp = 0.0;
double newtp = 0.0;
double newfp = 0.0;
double area = 0.0;
double auc = -1;
size_t i = 0;
while (i < records.size()) {
newtp = tp;
newfp = fp;
if (records[i].label_ == 1) {
newtp += 1;
} else {
newfp += 1;
}
// check i+1
while (i < records.size() - 1 && records[i].pred_ == records[i + 1].pred_) {
if (records[i + 1].label_ == 1) {
newtp += 1;
} else {
newfp += 1;
}
i += 1;
}
area += (newfp - fp) * (tp + newtp) / 2.0;
tp = newtp;
fp = newfp;
i += 1;
}
if (tp > 0 && fp > 0) {
auc = area / (fp * tp + 1e-9);
} else {
auc = -1;
}
return {tp, fp, auc};
}
} // namespace framework
} // namespace paddle
#endif
此差异已折叠。
set(PYBIND_DEPS init pybind python proto_desc memory executor fleet_wrapper box_wrapper prune set(PYBIND_DEPS init pybind python proto_desc memory executor fleet_wrapper box_wrapper metrics prune
feed_fetch_method pass generate_pass pass_builder parallel_executor profiler layer tracer engine scope_pool feed_fetch_method pass generate_pass pass_builder parallel_executor profiler layer tracer engine scope_pool
analysis_predictor imperative_profiler imperative_flag save_load_util dlpack_tensor device_context analysis_predictor imperative_profiler imperative_flag save_load_util dlpack_tensor device_context
gloo_wrapper infer_io_utils heter_wrapper generator op_version_registry ps_gpu_wrapper custom_operator gloo_wrapper infer_io_utils heter_wrapper generator op_version_registry ps_gpu_wrapper custom_operator
...@@ -63,6 +63,7 @@ set(PYBIND_SRCS ...@@ -63,6 +63,7 @@ set(PYBIND_SRCS
ps_gpu_wrapper_py.cc ps_gpu_wrapper_py.cc
gloo_wrapper_py.cc gloo_wrapper_py.cc
box_helper_py.cc box_helper_py.cc
metrics_py.cc
data_set_py.cc data_set_py.cc
imperative.cc imperative.cc
ir.cc ir.cc
......
...@@ -271,6 +271,8 @@ void BindDataset(py::module *m) { ...@@ -271,6 +271,8 @@ void BindDataset(py::module *m) {
py::call_guard<py::gil_scoped_release>()) py::call_guard<py::gil_scoped_release>())
.def("set_merge_by_sid", &framework::Dataset::SetMergeBySid, .def("set_merge_by_sid", &framework::Dataset::SetMergeBySid,
py::call_guard<py::gil_scoped_release>()) py::call_guard<py::gil_scoped_release>())
.def("set_shuffle_by_uid", &framework::Dataset::SetShuffleByUid,
py::call_guard<py::gil_scoped_release>())
.def("preprocess_instance", &framework::Dataset::PreprocessInstance, .def("preprocess_instance", &framework::Dataset::PreprocessInstance,
py::call_guard<py::gil_scoped_release>()) py::call_guard<py::gil_scoped_release>())
.def("postprocess_instance", &framework::Dataset::PostprocessInstance, .def("postprocess_instance", &framework::Dataset::PostprocessInstance,
......
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <fcntl.h>
#ifdef _POSIX_C_SOURCE
#undef _POSIX_C_SOURCE
#endif
#ifdef _XOPEN_SOURCE
#undef _XOPEN_SOURCE
#endif
#include <memory>
#include <string>
#include <vector>
#include "google/protobuf/text_format.h"
#include "paddle/fluid/framework/fleet/metrics.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/inference/io.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/platform/variant.h"
#include "paddle/fluid/pybind/metrics_py.h"
namespace py = pybind11;
#if defined(PADDLE_WITH_PSLIB)
namespace paddle {
namespace pybind {
void BindMetrics(py::module* m) {
py::class_<framework::Metric, std::shared_ptr<framework::Metric>>(*m,
"Metric")
.def(py::init([]() { return framework::Metric::SetInstance(); }))
.def("init_metric", &framework::Metric::InitMetric,
py::call_guard<py::gil_scoped_release>())
.def("flip_phase", &framework::Metric::FlipPhase,
py::call_guard<py::gil_scoped_release>())
.def("get_metric_msg", &framework::Metric::GetMetricMsg,
py::call_guard<py::gil_scoped_release>())
.def("get_wuauc_metric_msg", &framework::Metric::GetWuAucMetricMsg,
py::call_guard<py::gil_scoped_release>())
.def("get_metric_name_list", &framework::Metric::GetMetricNameList,
py::call_guard<py::gil_scoped_release>());
} // end Metrics
} // end namespace pybind
} // end namespace paddle
#endif
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "pybind11/pybind11.h"
#include "pybind11/stl.h"
namespace py = pybind11;
#if defined(PADDLE_WITH_PSLIB)
namespace paddle {
namespace pybind {
void BindMetrics(py::module* m);
} // namespace pybind
} // namespace paddle
#endif
...@@ -100,6 +100,7 @@ limitations under the License. */ ...@@ -100,6 +100,7 @@ limitations under the License. */
#include "paddle/fluid/pybind/imperative.h" #include "paddle/fluid/pybind/imperative.h"
#include "paddle/fluid/pybind/inference_api.h" #include "paddle/fluid/pybind/inference_api.h"
#include "paddle/fluid/pybind/ir.h" #include "paddle/fluid/pybind/ir.h"
#include "paddle/fluid/pybind/metrics_py.h"
#include "paddle/fluid/pybind/ps_gpu_wrapper_py.h" #include "paddle/fluid/pybind/ps_gpu_wrapper_py.h"
#include "paddle/fluid/pybind/pybind_boost_headers.h" #include "paddle/fluid/pybind/pybind_boost_headers.h"
...@@ -3678,6 +3679,7 @@ All parameter, weight, gradient are variables in Paddle. ...@@ -3678,6 +3679,7 @@ All parameter, weight, gradient are variables in Paddle.
#if defined(PADDLE_WITH_PSLIB) && !defined(PADDLE_WITH_HETERPS) #if defined(PADDLE_WITH_PSLIB) && !defined(PADDLE_WITH_HETERPS)
BindHeterWrapper(&m); BindHeterWrapper(&m);
BindMetrics(&m);
#endif #endif
#ifdef PADDLE_WITH_HETERPS #ifdef PADDLE_WITH_HETERPS
BindPSGPUWrapper(&m); BindPSGPUWrapper(&m);
......
...@@ -141,6 +141,23 @@ class DatasetBase(object): ...@@ -141,6 +141,23 @@ class DatasetBase(object):
def _set_input_type(self, input_type): def _set_input_type(self, input_type):
self.proto_desc.input_type = input_type self.proto_desc.input_type = input_type
def _set_uid_slot(self, uid_slot):
"""
Set user slot name.
Examples:
.. code-block:: python
import paddle
dataset = paddle.distributed.fleet.DatasetBase()
dataset._set_uid_slot('6048')
Args:
set_uid_slot(string): user slot name
"""
multi_slot = self.proto_desc.multi_slot_desc
multi_slot.uid_slot = uid_slot
def _set_use_var(self, var_list): def _set_use_var(self, var_list):
""" """
Set Variables which you will use. Set Variables which you will use.
...@@ -738,6 +755,23 @@ class InMemoryDataset(DatasetBase): ...@@ -738,6 +755,23 @@ class InMemoryDataset(DatasetBase):
self.merge_by_lineid = True self.merge_by_lineid = True
self.parse_ins_id = True self.parse_ins_id = True
def _set_shuffle_by_uid(self, enable_shuffle_uid):
"""
Set if Dataset need to shuffle by uid.
Args:
set_shuffle_by_uid(bool): if shuffle according to uid or not
Examples:
.. code-block:: python
import paddle
paddle.enable_static()
dataset = paddle.distributed.InMemoryDataset()
dataset._set_shuffle_by_uid(True)
"""
self.dataset.set_shuffle_by_uid(enable_shuffle_uid)
def _set_generate_unique_feasigns(self, generate_uni_feasigns, shard_num): def _set_generate_unique_feasigns(self, generate_uni_feasigns, shard_num):
self.dataset.set_generate_unique_feasigns(generate_uni_feasigns) self.dataset.set_generate_unique_feasigns(generate_uni_feasigns)
self.gen_uni_feasigns = generate_uni_feasigns self.gen_uni_feasigns = generate_uni_feasigns
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .metrics import init_metric # noqa: F401
from .metrics import print_auc # noqa: F401
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import yaml
import paddle.fluid as fluid
import logging
from paddle.distributed.utils import get_logger
__all__ = []
logger = get_logger(logging.INFO, name="metrics")
# read metric config from yaml and init MetricMsg in fleet_wrapper
def init_metric(metric_ptr,
metric_yaml_path,
cmatch_rank_var="",
mask_var="",
uid_var="",
phase=-1,
cmatch_rank_group="",
ignore_rank=False,
bucket_size=1000000):
yaml_fobj = open(metric_yaml_path)
if sys.version.startswith('2.7.13'):
content = yaml.load(yaml_fobj)
else:
content = yaml.load(yaml_fobj, Loader=yaml.FullLoader)
print("yaml metric config: \n")
print(content)
metric_runner_list = content['monitors']
if not metric_runner_list:
metric_runner_list = []
for metric_runner in metric_runner_list:
is_join = metric_runner['phase'] == 'JOINING'
phase = 1 if is_join else 0
if metric_runner['method'] == 'AucCalculator':
metric_ptr.init_metric(
metric_runner['method'], metric_runner['name'],
metric_runner['label'], metric_runner['target'],
cmatch_rank_var, mask_var, uid_var, phase, cmatch_rank_group,
ignore_rank, bucket_size)
elif metric_runner['method'] == 'MultiTaskAucCalculator':
metric_ptr.init_metric(
metric_runner['method'], metric_runner['name'],
metric_runner['label'], metric_runner['target'],
metric_runner['cmatch_var'], mask_var, uid_var, phase,
metric_runner['cmatch_group'], ignore_rank, bucket_size)
elif metric_runner['method'] == 'CmatchRankAucCalculator':
metric_ptr.init_metric(
metric_runner['method'], metric_runner['name'],
metric_runner['label'], metric_runner['target'],
metric_runner['cmatch_var'], mask_var, uid_var, phase,
metric_runner['cmatch_group'], metric_runner['ignore_rank'],
bucket_size)
elif metric_runner['method'] == 'MaskAucCalculator':
metric_ptr.init_metric(
metric_runner['method'], metric_runner['name'],
metric_runner['label'], metric_runner['target'],
cmatch_rank_var, metric_runner['mask'], uid_var, phase,
cmatch_rank_group, ignore_rank, bucket_size)
elif metric_runner['method'] == 'CmatchRankMaskAucCalculator':
metric_ptr.init_metric(
metric_runner['method'], metric_runner['name'],
metric_runner['label'], metric_runner['target'],
metric_runner['cmatch_var'], metric_runner['mask'], uid_var,
phase, metric_runner['cmatch_group'],
metric_runner['ignore_rank'], bucket_size)
elif metric_runner['method'] == 'WuAucCalculator':
metric_ptr.init_metric(
metric_runner['method'], metric_runner['name'],
metric_runner['label'], metric_runner['target'],
cmatch_rank_var, mask_var, metric_runner['uid'], phase,
cmatch_rank_group, ignore_rank, bucket_size)
else:
metric_ptr.init_metric(
metric_runner['method'], metric_runner['name'],
metric_runner['label'], metric_runner['target'],
cmatch_rank_var, mask_var, phase, cmatch_rank_group,
ignore_rank, bucket_size)
def print_metric(metric_ptr, name):
"""
print the metric value. Print directly in back-end
"""
if name.find("wuauc") != -1:
metric = metric_ptr.get_wuauc_metric_msg(name)
monitor_msg = "%s: User Count=%.0f INS Count=%.0f UAUC=%.6f WUAUC=%.6f "\
% (name, metric[0], metric[1], metric[4], metric[5])
else:
metric = metric_ptr.get_metric_msg(name)
monitor_msg = "%s: AUC=%.6f BUCKET_ERROR=%.6f MAE=%.6f RMSE=%.6f "\
"Actual CTR=%.6f Predicted CTR=%.6f COPC=%.6f INS Count=%.0f"\
% (name, metric[0], metric[1], metric[2], metric[3], metric[4],
metric[5], metric[6], metric[7])
# logger.info(monitor_msg)
return monitor_msg
def print_auc(metric_ptr, is_day, phase="all"):
"""
print metric according to stage and phase
"""
if is_day is True:
stage = "day"
stage_num = -1
else:
stage = "pass"
stage_num = 1 if phase == "join" else 0
metric_results = []
name_list = metric_ptr.get_metric_name_list(stage_num)
if phase == "all":
for name in name_list:
if name.find(stage) != -1:
metric_results.append(print_metric(metric_ptr, name=name))
else:
for name in name_list:
if name.find(stage) != -1 and name.find(phase) != -1:
metric_results.append(print_metric(metric_ptr, name=name))
return metric_results
...@@ -598,6 +598,7 @@ class GeneralRoleMaker(RoleMakerBase): ...@@ -598,6 +598,7 @@ class GeneralRoleMaker(RoleMakerBase):
self._hdfs_path = kwargs.get("path", "").rstrip("/") self._hdfs_path = kwargs.get("path", "").rstrip("/")
self._init_timeout_seconds = kwargs.get("init_timeout_seconds", 3600) self._init_timeout_seconds = kwargs.get("init_timeout_seconds", 3600)
self._run_timeout_seconds = kwargs.get("run_timeout_seconds", 9999999) self._run_timeout_seconds = kwargs.get("run_timeout_seconds", 9999999)
self._use_metric = kwargs.get("use_metric", False)
ip_port = kwargs.get("http_ip_port", "") ip_port = kwargs.get("http_ip_port", "")
self._use_ps_gpu = kwargs.get("use_ps_gpu", False) self._use_ps_gpu = kwargs.get("use_ps_gpu", False)
self._http_ip_port = [] self._http_ip_port = []
...@@ -668,7 +669,7 @@ class GeneralRoleMaker(RoleMakerBase): ...@@ -668,7 +669,7 @@ class GeneralRoleMaker(RoleMakerBase):
self._hdfs_name, self._hdfs_ugi) self._hdfs_name, self._hdfs_ugi)
gloo.init() gloo.init()
self._node_type_comm = gloo self._node_type_comm = gloo
if self._use_ps_gpu: if self._use_ps_gpu or self._use_metric:
Gloo_strategy = fluid.core.GlooParallelStrategy() Gloo_strategy = fluid.core.GlooParallelStrategy()
Gloo_strategy.rank = current_id Gloo_strategy.rank = current_id
Gloo_strategy.rank_num = len(worker_endpoints) Gloo_strategy.rank_num = len(worker_endpoints)
......
...@@ -70,6 +70,14 @@ class TestDataset(unittest.TestCase): ...@@ -70,6 +70,14 @@ class TestDataset(unittest.TestCase):
self.assertTrue(dataset.parse_content) self.assertTrue(dataset.parse_content)
self.assertEqual(dataset.trainer_num, 1) self.assertEqual(dataset.trainer_num, 1)
def test_shuffle_by_uid(self):
"""
Testcase for shuffle_by_uid.
"""
dataset = paddle.distributed.InMemoryDataset()
dataset._set_uid_slot('6048')
dataset._set_shuffle_by_uid(True)
def test_run_with_dump(self): def test_run_with_dump(self):
""" """
Testcase for InMemoryDataset from create to run. Testcase for InMemoryDataset from create to run.
......
...@@ -269,6 +269,7 @@ packages=['paddle', ...@@ -269,6 +269,7 @@ packages=['paddle',
'paddle.dataset', 'paddle.dataset',
'paddle.reader', 'paddle.reader',
'paddle.distributed', 'paddle.distributed',
'paddle.distributed.metric',
'paddle.incubate', 'paddle.incubate',
'paddle.incubate.optimizer', 'paddle.incubate.optimizer',
'paddle.incubate.checkpoint', 'paddle.incubate.checkpoint',
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册