未验证 提交 7460a891 编写于 作者: F Fan Zhang 提交者: GitHub

[PSLIB] Add Metrics Module, Support User-defined Add Metric (#38230)

* 12.3 first add metrics module

* add Mask/MultiTask

* add WuAUC

* [PSLIB] Update WuAUC Compute

* [PSLIB] Change WuAUC Compute Method

* [PSLIB] Clean WuAUC Compute

* [PSLIB] Clean Metric Module Unused Code

* mv metric instance

* [PSLIB] Add Metrics Module, Support User-defined Add Metric (#38789)

* [PSLIB] Add Metrics Module, Support User-defined Add Metric

* [PSLIB] Modify According to CI

* [PSLIB] Modify According to CI

* [PSLIB] Modify According to CI

* [PSLIB] Modify According to CI Coverage

* [PSLIB] Modify According to CI

* [PSLIB] Modify According to CI

* [PSLIB] Modify According to CI

* [PSLIB] Modify According to CI

* [PSLIB] Modify According to CI

* [PSLIB] Modify According to CI Coverage

* [PSLIB] Modify According to CI Coverage

* [PSLIB] Modify According to CI Coverage

* modify role_maker

* update CMakeLists.txt
上级 d3011c75
......@@ -188,7 +188,7 @@ if(WITH_DISTRIBUTE)
dist_multi_trainer.cc trainer_factory.cc trainer.cc data_feed_factory.cc
data_feed.cc device_worker.cc hogwild_worker.cc downpour_worker.cc downpour_worker_opt.cc
pull_dense_worker.cc section_worker.cc device_worker_factory.cc data_set.cc DEPS op_registry
device_context scope framework_proto trainer_desc_proto glog fs shell fleet_wrapper box_wrapper lodtensor_printer
device_context scope framework_proto trainer_desc_proto glog fs shell fleet_wrapper box_wrapper metrics lodtensor_printer
lod_rank_table feed_fetch_method sendrecvop_rpc communicator collective_helper ${GLOB_DISTRIBUTE_DEPS}
graph_to_program_pass variable_helper data_feed_proto timer)
set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor")
......@@ -199,7 +199,7 @@ else()
data_feed.cc device_worker.cc hogwild_worker.cc downpour_worker.cc downpour_worker_opt.cc
pull_dense_worker.cc section_worker.cc device_worker_factory.cc data_set.cc DEPS op_registry
device_context scope framework_proto data_feed_proto trainer_desc_proto glog
lod_rank_table fs shell fleet_wrapper box_wrapper lodtensor_printer feed_fetch_method
lod_rank_table fs shell fleet_wrapper box_wrapper metrics lodtensor_printer feed_fetch_method
graph_to_program_pass variable_helper timer)
cc_test(test_naive_executor SRCS naive_executor_test.cc DEPS naive_executor elementwise_add_op)
endif()
......
......@@ -232,6 +232,7 @@ InMemoryDataFeed<T>::InMemoryDataFeed() {
this->thread_id_ = 0;
this->thread_num_ = 1;
this->parse_ins_id_ = false;
this->parse_uid_ = false;
this->parse_content_ = false;
this->parse_logkey_ = false;
this->enable_pv_merge_ = false;
......@@ -362,6 +363,11 @@ void InMemoryDataFeed<T>::SetParseInsId(bool parse_ins_id) {
parse_ins_id_ = parse_ins_id;
}
// Stores the caller's choice of whether the uid slot should be parsed into
// each record while reading instances.
template <typename T>
void InMemoryDataFeed<T>::SetParseUid(bool parse_uid) {
  this->parse_uid_ = parse_uid;
}
template <typename T>
void InMemoryDataFeed<T>::LoadIntoMemory() {
#ifdef _LINUX
......@@ -838,6 +844,7 @@ void MultiSlotInMemoryDataFeed::Init(
use_slots_shape_.push_back(local_shape);
}
}
uid_slot_ = multi_slot_desc.uid_slot();
feed_vec_.resize(use_slots_.size());
pipe_command_ = data_feed_desc.pipe_command();
finish_init_ = true;
......@@ -929,6 +936,17 @@ bool MultiSlotInMemoryDataFeed::ParseOneInstanceFromPipe(Record* instance) {
"\nWe detect the feasign number of this slot is %d, "
"which is illegal.",
str, i, num));
if (parse_uid_ && all_slots_[i] == uid_slot_) {
  // The uid slot must carry exactly one value and be declared as uint64.
  PADDLE_ENFORCE(num == 1 && all_slots_type_[i][0] == 'u',
                 "The uid has to be uint64 and single.\n"
                 "please check this error line: %s",
                 str);
  // Parse through a scratch cursor so that `endptr` itself is not advanced
  // by this read.
  char* uid_cursor = endptr;
  const uint64_t uid_value =
      static_cast<uint64_t>(strtoull(uid_cursor, &uid_cursor, 10));
  // Record::uid_ is a std::string. Assigning the integer directly would pick
  // std::string::operator=(char) and keep only the low byte of the uid, so
  // the value is stored as its full decimal text instead.
  instance->uid_ = std::to_string(uid_value);
}
if (idx != -1) {
if (all_slots_type_[i][0] == 'f') { // float
for (int j = 0; j < num; ++j) {
......
......@@ -92,6 +92,7 @@ struct Record {
uint64_t search_id;
uint32_t rank;
uint32_t cmatch;
std::string uid_;
};
struct PvInstanceObject {
......@@ -157,6 +158,7 @@ class DataFeed {
virtual void SetThreadNum(int thread_num) {}
// This function does nothing by default
virtual void SetParseInsId(bool parse_ins_id) {}
virtual void SetParseUid(bool parse_uid) {}
virtual void SetParseContent(bool parse_content) {}
virtual void SetParseLogKey(bool parse_logkey) {}
virtual void SetEnablePvMerge(bool enable_pv_merge) {}
......@@ -232,6 +234,7 @@ class DataFeed {
std::vector<std::string> ins_id_vec_;
std::vector<std::string> ins_content_vec_;
platform::Place place_;
std::string uid_slot_;
};
// PrivateQueueDataFeed is the base virtual class for other DataFeeds.
......@@ -293,6 +296,7 @@ class InMemoryDataFeed : public DataFeed {
virtual void SetThreadId(int thread_id);
virtual void SetThreadNum(int thread_num);
virtual void SetParseInsId(bool parse_ins_id);
virtual void SetParseUid(bool parse_uid);
virtual void SetParseContent(bool parse_content);
virtual void SetParseLogKey(bool parse_logkey);
virtual void SetEnablePvMerge(bool enable_pv_merge);
......@@ -307,6 +311,7 @@ class InMemoryDataFeed : public DataFeed {
int thread_id_;
int thread_num_;
bool parse_ins_id_;
bool parse_uid_;
bool parse_content_;
bool parse_logkey_;
bool enable_pv_merge_;
......@@ -471,7 +476,7 @@ paddle::framework::Archive<AR>& operator>>(paddle::framework::Archive<AR>& ar,
for (size_t& x : offset) {
uint64_t t;
ar >> t;
x = (size_t)t;
x = static_cast<size_t>(t);
}
#endif
ar >> ins.MutableFloatData();
......
......@@ -22,7 +22,10 @@ message Slot {
repeated int32 shape = 5; // we can define N-D Tensor
}
message MultiSlotDesc { repeated Slot slots = 1; }
message MultiSlotDesc {
repeated Slot slots = 1;
optional string uid_slot = 2;
}
message DataFeedDesc {
optional string name = 1;
......
......@@ -54,6 +54,8 @@ DatasetImpl<T>::DatasetImpl() {
parse_logkey_ = false;
preload_thread_num_ = 0;
global_index_ = 0;
shuffle_by_uid_ = false;
parse_uid_ = false;
}
// set filelist, file_idx_ will reset to zero.
......@@ -147,6 +149,12 @@ void DatasetImpl<T>::SetMergeBySid(bool is_merge) {
merge_by_sid_ = is_merge;
}
// Records whether the shuffle should be keyed by uid and switches uid
// parsing on.
template <typename T>
void DatasetImpl<T>::SetShuffleByUid(bool enable_shuffle_uid) {
  // uid parsing is enabled unconditionally, including when
  // enable_shuffle_uid is false.
  // NOTE(review): presumably so uid-based metrics can still read the uid;
  // confirm this is intentional rather than `parse_uid_ = enable_shuffle_uid`.
  this->parse_uid_ = true;
  this->shuffle_by_uid_ = enable_shuffle_uid;
}
template <typename T>
void DatasetImpl<T>::SetEnablePvMerge(bool enable_pv_merge) {
enable_pv_merge_ = enable_pv_merge;
......@@ -386,11 +394,14 @@ void DatasetImpl<T>::GlobalShuffle(int thread_num) {
<< input_channel_->Size();
auto get_client_id = [this, fleet_ptr](const T& data) -> size_t {
if (!this->merge_by_insid_) {
return fleet_ptr->LocalRandomEngine()() % this->trainer_num_;
} else {
if (this->merge_by_insid_) {
return XXH64(data.ins_id_.data(), data.ins_id_.length(), 0) %
this->trainer_num_;
} else if (this->shuffle_by_uid_) {
return XXH64(data.uid_.data(), data.uid_.length(), 0) %
this->trainer_num_;
} else {
return fleet_ptr->LocalRandomEngine()() % this->trainer_num_;
}
};
......@@ -618,6 +629,7 @@ void DatasetImpl<T>::CreateReaders() {
readers_[i]->SetFileListIndex(&file_idx_);
readers_[i]->SetFileList(filelist_);
readers_[i]->SetParseInsId(parse_ins_id_);
readers_[i]->SetParseUid(parse_uid_);
readers_[i]->SetParseContent(parse_content_);
readers_[i]->SetParseLogKey(parse_logkey_);
readers_[i]->SetEnablePvMerge(enable_pv_merge_);
......@@ -686,6 +698,7 @@ void DatasetImpl<T>::CreatePreLoadReaders() {
preload_readers_[i]->SetFileListIndex(&file_idx_);
preload_readers_[i]->SetFileList(filelist_);
preload_readers_[i]->SetParseInsId(parse_ins_id_);
preload_readers_[i]->SetParseUid(parse_uid_);
preload_readers_[i]->SetParseContent(parse_content_);
preload_readers_[i]->SetParseLogKey(parse_logkey_);
preload_readers_[i]->SetEnablePvMerge(enable_pv_merge_);
......
......@@ -68,6 +68,7 @@ class Dataset {
virtual void SetParseLogKey(bool parse_logkey) = 0;
virtual void SetEnablePvMerge(bool enable_pv_merge) = 0;
virtual void SetMergeBySid(bool is_merge) = 0;
virtual void SetShuffleByUid(bool enable_shuffle_uid) = 0;
// set merge by ins id
virtual void SetMergeByInsId(int merge_size) = 0;
virtual void SetGenerateUniqueFeasign(bool gen_uni_feasigns) = 0;
......@@ -175,6 +176,7 @@ class DatasetImpl : public Dataset {
virtual void SetParseLogKey(bool parse_logkey);
virtual void SetEnablePvMerge(bool enable_pv_merge);
virtual void SetMergeBySid(bool is_merge);
virtual void SetShuffleByUid(bool enable_shuffle_uid);
virtual void SetMergeByInsId(int merge_size);
virtual void SetGenerateUniqueFeasign(bool gen_uni_feasigns);
......@@ -263,6 +265,8 @@ class DatasetImpl : public Dataset {
bool parse_content_;
bool parse_logkey_;
bool merge_by_sid_;
bool shuffle_by_uid_;
bool parse_uid_;
bool enable_pv_merge_; // True means to merge pv
int current_phase_; // 1 join, 0 update
size_t merge_size_;
......
......@@ -15,6 +15,7 @@ limitations under the License. */
#include "paddle/fluid/framework/device_worker.h"
#include "paddle/fluid/framework/device_worker_factory.h"
#include "paddle/fluid/framework/fleet/fleet_wrapper.h"
#include "paddle/fluid/framework/fleet/metrics.h"
#include "paddle/fluid/platform/cpu_helper.h"
#include "paddle/fluid/string/string_helper.h"
......@@ -25,7 +26,6 @@ limitations under the License. */
namespace paddle {
namespace framework {
void DownpourWorker::Initialize(const TrainerDesc& desc) {
param_ = desc.downpour_param();
for (int i = 0; i < param_.sparse_table_size(); ++i) {
......@@ -780,6 +780,21 @@ void DownpourWorker::TrainFilesWithProfiler() {
}
}
/**
 * @brief Feed the current scope into every registered metric whose phase
 *        equals the phase currently reported by the Metric instance.
 *
 * @param scope scope handed to each metric's add_data
 * @param place place handed to each metric's add_data
 *
 * The Metric instance is dereferenced without a null check here; the visible
 * caller checks it against nullptr before calling.
 */
inline void AddAucMonitor(const Scope* scope, const platform::Place& place) {
  auto metric_ptr = Metric::GetInstance();
  auto& metric_list = metric_ptr->GetMetricList();
  for (auto& entry : metric_list) {
    auto* metric_msg = entry.second;
    // Only metrics registered for the active phase receive data.
    if (metric_ptr->Phase() == metric_msg->MetricPhase()) {
      metric_msg->add_data(scope, place);
    }
  }
}
void DownpourWorker::TrainFiles() {
VLOG(3) << "Begin to train files";
platform::SetNumThreads(1);
......@@ -877,6 +892,11 @@ void DownpourWorker::TrainFiles() {
}
}
// add data for MetricMsg
if (Metric::GetInstance() != nullptr) {
AddAucMonitor(thread_scope_, place_);
}
// check inf and nan
for (std::string& var_name : check_nan_var_names_) {
Variable* var = thread_scope_->FindVar(var_name);
......
......@@ -15,8 +15,10 @@ endif(WITH_BOX_PS)
if(WITH_GLOO)
cc_library(gloo_wrapper SRCS gloo_wrapper.cc DEPS framework_proto variable_helper scope gloo)
cc_library(metrics SRCS metrics.cc DEPS gloo_wrapper)
else()
cc_library(gloo_wrapper SRCS gloo_wrapper.cc DEPS framework_proto variable_helper scope)
cc_library(metrics SRCS metrics.cc DEPS gloo_wrapper)
endif(WITH_GLOO)
cc_test(test_fleet SRCS test_fleet.cc DEPS fleet_wrapper gloo_wrapper fs shell)
......@@ -105,6 +105,11 @@ enum GlooStoreType { HDFS, HTTP };
class GlooWrapper {
public:
static std::shared_ptr<GlooWrapper> GetInstance() {
static auto s_instance = std::make_shared<GlooWrapper>();
return s_instance;
}
GlooWrapper() {}
virtual ~GlooWrapper() {}
......@@ -153,6 +158,8 @@ class GlooWrapper {
#endif
}
// Reports the current value of is_initialized_.
bool IsInitialized() {
  return is_initialized_;
}
template <typename T>
std::vector<T> AllReduce(std::vector<T>& sendbuf, // NOLINT
const std::string& mode = "sum") { // NOLINT
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/fleet/metrics.h"
#include <algorithm>
#include <ctime>
#include <memory>
#include <numeric>
#include "paddle/fluid/framework/lod_tensor.h"
namespace paddle {
namespace framework {
std::shared_ptr<Metric> Metric::s_instance_ = nullptr;
// Configures the calculator for `table_size` buckets and clears all
// accumulated state.
void BasicAucCalculator::init(int table_size) {
  set_table_size(table_size);
  // Start both CPU-side tables (index 0 and index 1) from empty vectors;
  // reset() below gives them _table_size zeroed buckets.
  _table[0] = std::vector<double>();
  _table[1] = std::vector<double>();
  reset();
}
// Zero-fills both bucket tables to _table_size entries and clears the local
// error and prediction accumulators.
void BasicAucCalculator::reset() {
  _table[0].assign(_table_size, 0.0);
  _table[1].assign(_table_size, 0.0);
  _local_pred = 0;
  _local_sqrerr = 0;
  _local_abserr = 0;
}
// Adds `batch_size` (prediction, label) pairs to the bucket tables.
//
// The inputs are first copied into thread-local staging buffers, then folded
// in one by one while holding _table_mutex. `place` is not used by this
// implementation.
void BasicAucCalculator::add_data(const float* d_pred, const int64_t* d_label,
                                  int batch_size,
                                  const paddle::platform::Place& place) {
  thread_local std::vector<float> pred_buf;
  thread_local std::vector<int64_t> label_buf;

  pred_buf.resize(batch_size);
  memcpy(pred_buf.data(), d_pred, sizeof(float) * batch_size);
  label_buf.resize(batch_size);
  memcpy(label_buf.data(), d_label, sizeof(int64_t) * batch_size);

  // The copies above happen before the lock is taken; only the table update
  // is serialized.
  std::lock_guard<std::mutex> lock(_table_mutex);
  int idx = 0;
  while (idx < batch_size) {
    add_unlock_data(pred_buf[idx], label_buf[idx]);
    ++idx;
  }
}
// Folds one (pred, label) sample into the tables and the running error sums.
// Takes no lock itself; the visible callers hold _table_mutex around it.
//
// Enforces 0 <= pred <= 1 and label in {0, 1} (checked as label*label ==
// label), then increments bucket `pos` of table `label`, where pos is
// pred * _table_size truncated to int and capped at _table_size - 1.
void BasicAucCalculator::add_unlock_data(double pred, int label) {
  PADDLE_ENFORCE_GE(pred, 0.0, platform::errors::PreconditionNotMet(
                                   "pred should be greater than 0"));
  PADDLE_ENFORCE_LE(pred, 1.0, platform::errors::PreconditionNotMet(
                                   "pred should be lower than 1"));
  PADDLE_ENFORCE_EQ(
      label * label, label,
      platform::errors::PreconditionNotMet(
          "label must be equal to 0 or 1, but its value is: %d", label));

  // pred == 1.0 would land one past the last bucket, hence the cap.
  int pos = static_cast<int>(pred * _table_size);
  if (_table_size - 1 < pos) {
    pos = _table_size - 1;
  }
  PADDLE_ENFORCE_GE(
      pos, 0,
      platform::errors::PreconditionNotMet(
          "pos must be equal or greater than 0, but its value is: %d", pos));
  PADDLE_ENFORCE_LT(
      pos, _table_size,
      platform::errors::PreconditionNotMet(
          "pos must be less than table_size, but its value is: %d", pos));

  const double diff = pred - label;
  _local_abserr += fabs(diff);
  _local_sqrerr += diff * diff;
  _local_pred += pred;
  ++_table[label][pos];
}
// Masked variant of add_data: a sample is folded into the tables only when
// its mask entry is non-zero.
//
// Inputs are copied into thread-local staging buffers first; the table
// update runs under _table_mutex. `place` is not used by this implementation.
void BasicAucCalculator::add_mask_data(const float* d_pred,
                                       const int64_t* d_label,
                                       const int64_t* d_mask, int batch_size,
                                       const paddle::platform::Place& place) {
  thread_local std::vector<float> pred_buf;
  thread_local std::vector<int64_t> label_buf;
  thread_local std::vector<int64_t> mask_buf;

  pred_buf.resize(batch_size);
  memcpy(pred_buf.data(), d_pred, sizeof(float) * batch_size);
  label_buf.resize(batch_size);
  memcpy(label_buf.data(), d_label, sizeof(int64_t) * batch_size);
  mask_buf.resize(batch_size);
  memcpy(mask_buf.data(), d_mask, sizeof(int64_t) * batch_size);

  std::lock_guard<std::mutex> lock(_table_mutex);
  for (int idx = 0; idx < batch_size; ++idx) {
    if (!mask_buf[idx]) {
      continue;  // masked out
    }
    add_unlock_data(pred_buf[idx], label_buf[idx]);
  }
}
// Derives _auc, _mae, _rmse, _predicted_ctr, _actual_ctr and _size from the
// bucket tables and error accumulators, then calls calculate_bucket_error().
//
// When the Gloo wrapper reports more than one participant, the tables and
// the three scalar accumulators are sum-all-reduced first; otherwise the
// local values are used as-is. If PADDLE_WITH_GLOO is not defined the whole
// body is compiled out and this function does nothing.
void BasicAucCalculator::compute() {
#if defined(PADDLE_WITH_GLOO)
  auto gloo_wrapper = paddle::framework::GlooWrapper::GetInstance();
  if (!gloo_wrapper->IsInitialized()) {
    VLOG(0) << "GLOO is not inited";
    gloo_wrapper->Init();
  }

  double area = 0;
  double fp = 0;
  double tp = 0;
  // Walk the buckets from the highest index down, accumulating the area
  // under the (fp, tp) curve one trapezoid per bucket.
  auto sweep = [&](const std::vector<double>& neg,
                   const std::vector<double>& pos) {
    for (int i = _table_size - 1; i >= 0; i--) {
      const double next_fp = fp + neg[i];
      const double next_tp = tp + pos[i];
      area += (next_fp - fp) * (tp + next_tp) / 2;
      fp = next_fp;
      tp = next_tp;
    }
  };
  if (gloo_wrapper->Size() > 1) {
    auto neg_table = gloo_wrapper->AllReduce(_table[0], "sum");
    auto pos_table = gloo_wrapper->AllReduce(_table[1], "sum");
    sweep(neg_table, pos_table);
  } else {
    sweep(_table[0], _table[1]);
  }

  // -0.5 marks the case where one of the two label totals is (near) zero,
  // i.e. all nonclick or all click.
  _auc = (fp < 1e-3 || tp < 1e-3) ? -0.5 : area / (fp * tp);

  double sum_abserr = _local_abserr;
  double sum_sqrerr = _local_sqrerr;
  double sum_pred = _local_pred;
  if (gloo_wrapper->Size() > 1) {
    // allreduce sum
    std::vector<double> local_abserr_vec(1, _local_abserr);
    std::vector<double> local_sqrerr_vec(1, _local_sqrerr);
    std::vector<double> local_pred_vec(1, _local_pred);
    auto global_abserr_vec = gloo_wrapper->AllReduce(local_abserr_vec, "sum");
    auto global_sqrerr_vec = gloo_wrapper->AllReduce(local_sqrerr_vec, "sum");
    auto global_pred_vec = gloo_wrapper->AllReduce(local_pred_vec, "sum");
    sum_abserr = global_abserr_vec[0];
    sum_sqrerr = global_sqrerr_vec[0];
    sum_pred = global_pred_vec[0];
  }

  // No guard for an empty table: a zero total yields non-finite ratios here.
  const double total = fp + tp;
  _mae = sum_abserr / total;
  _rmse = sqrt(sum_sqrerr / total);
  _predicted_ctr = sum_pred / total;
  _actual_ctr = tp / total;
  _size = total;

  calculate_bucket_error();
#endif
}
// Computes _bucket_error as the impression-weighted mean of
// |actual_ctr / adjust_ctr - 1| over runs of adjacent buckets whose estimated
// relative error falls below kRelativeErrorBound.
//
// The tables are sum-all-reduced first when the Gloo wrapper reports more
// than one participant. If PADDLE_WITH_GLOO is not defined the body is
// compiled out and _bucket_error is left untouched.
void BasicAucCalculator::calculate_bucket_error() {
#if defined(PADDLE_WITH_GLOO)
  double last_ctr = -1;
  double impression_sum = 0;
  double ctr_sum = 0.0;
  double click_sum = 0.0;
  double error_sum = 0.0;
  double error_count = 0;

  auto scan = [&](const std::vector<double>& neg,
                  const std::vector<double>& pos) {
    for (int i = 0; i < _table_size; i++) {
      const double click = pos[i];
      const double show = neg[i] + pos[i];
      const double ctr = static_cast<double>(i) / _table_size;
      // Start a fresh run once the current bucket is further than kMaxSpan
      // from where the run began.
      if (fabs(ctr - last_ctr) > kMaxSpan) {
        last_ctr = ctr;
        impression_sum = 0.0;
        ctr_sum = 0.0;
        click_sum = 0.0;
      }
      impression_sum += show;
      ctr_sum += ctr * show;
      click_sum += click;

      const double adjust_ctr = ctr_sum / impression_sum;
      const double relative_error =
          sqrt((1 - adjust_ctr) / (adjust_ctr * impression_sum));
      // A run with zero impressions or zero ctr gives a NaN/inf estimate,
      // which fails this comparison and so is simply carried forward.
      if (relative_error < kRelativeErrorBound) {
        const double actual_ctr = click_sum / impression_sum;
        const double relative_ctr_error = fabs(actual_ctr / adjust_ctr - 1);
        error_sum += relative_ctr_error * impression_sum;
        error_count += impression_sum;
        // Forces the next bucket to open a new run.
        last_ctr = -1;
      }
    }
  };

  auto gloo_wrapper = paddle::framework::GlooWrapper::GetInstance();
  if (gloo_wrapper->Size() > 1) {
    auto neg_table = gloo_wrapper->AllReduce(_table[0], "sum");
    auto pos_table = gloo_wrapper->AllReduce(_table[1], "sum");
    scan(neg_table, pos_table);
  } else {
    scan(_table[0], _table[1]);
  }

  if (error_count > 0) {
    _bucket_error = error_sum / error_count;
  } else {
    _bucket_error = 0.0;
  }
#endif
}
// Drops all stored per-user records and zeroes the user count, size and the
// two user-AUC accumulators.
void BasicAucCalculator::reset_records() {
  _wuauc = 0;
  _uauc = 0;
  _size = 0;
  _user_cnt = 0;
  // reset wuauc_records_
  wuauc_records_.clear();
}
// Appends `batch_size` (pred, label, uid) samples to the per-user record
// list.
//
// Inputs are copied into thread-local staging buffers first; the int64 uid
// bytes are copied as-is into a uint64 buffer. Records are appended while
// holding _table_mutex. `place` is not used by this implementation.
void BasicAucCalculator::add_uid_data(const float* d_pred,
                                      const int64_t* d_label,
                                      const int64_t* d_uid, int batch_size,
                                      const paddle::platform::Place& place) {
  thread_local std::vector<float> pred_buf;
  thread_local std::vector<int64_t> label_buf;
  thread_local std::vector<uint64_t> uid_buf;

  pred_buf.resize(batch_size);
  memcpy(pred_buf.data(), d_pred, sizeof(float) * batch_size);
  label_buf.resize(batch_size);
  memcpy(label_buf.data(), d_label, sizeof(int64_t) * batch_size);
  uid_buf.resize(batch_size);
  memcpy(uid_buf.data(), d_uid, sizeof(uint64_t) * batch_size);

  std::lock_guard<std::mutex> lock(_table_mutex);
  for (int idx = 0; idx < batch_size; ++idx) {
    add_uid_unlock_data(pred_buf[idx], label_buf[idx], uid_buf[idx]);
  }
}
void BasicAucCalculator::add_uid_unlock_data(double pred, int label,
uint64_t uid) {
PADDLE_ENFORCE_GE(pred, 0.0, platform::errors::PreconditionNotMet(
"pred should be greater than 0"));
PADDLE_ENFORCE_LE(pred, 1.0, platform::errors::PreconditionNotMet(
"pred should be lower than 1"));
PADDLE_ENFORCE_EQ(
label * label, label,
platform::errors::PreconditionNotMet(
"label must be equal to 0 or 1, but its value is: %d", label));
WuaucRecord record;
record.uid_ = uid;
record.label_ = label;
record.pred_ = pred;
wuauc_records_.emplace_back(std::move(record));
}
void BasicAucCalculator::computeWuAuc() {
std::sort(wuauc_records_.begin(), wuauc_records_.end(),
[](const WuaucRecord& lhs, const WuaucRecord& rhs) {
if (lhs.uid_ == rhs.uid_) {
if (lhs.pred_ == rhs.pred_) {
return lhs.label_ < rhs.label_;
} else {
return lhs.pred_ > rhs.pred_;
}
} else {
return lhs.uid_ > rhs.uid_;
}
});
WuaucRocData roc_data;
uint64_t prev_uid = 0;
size_t prev_pos = 0;
for (size_t i = 0; i < wuauc_records_.size(); ++i) {
if (wuauc_records_[i].uid_ != prev_uid) {
std::vector<WuaucRecord> single_user_recs(
wuauc_records_.begin() + prev_pos, wuauc_records_.begin() + i);
roc_data = computeSingelUserAuc(single_user_recs);
if (roc_data.auc_ != -1) {
double ins_num = (roc_data.tp_ + roc_data.fp_);
_user_cnt += 1;
_size += ins_num;
_uauc += roc_data.auc_;
_wuauc += roc_data.auc_ * ins_num;
}
prev_uid = wuauc_records_[i].uid_;
prev_pos = i;
}
}
std::vector<WuaucRecord> single_user_recs(wuauc_records_.begin() + prev_pos,
wuauc_records_.end());
roc_data = computeSingelUserAuc(single_user_recs);
if (roc_data.auc_ != -1) {
double ins_num = (roc_data.tp_ + roc_data.fp_);
_user_cnt += 1;
_size += ins_num;
_uauc += roc_data.auc_;
_wuauc += roc_data.auc_ * ins_num;
}
}
// Computes the AUC of one user's records, visiting them in the order given.
//
// Consecutive records with equal pred_ are treated as a single step of the
// curve. Returns {tp, fp, auc}, where auc is -1 when the user has no
// label-1 record or no other-label record (including the empty input).
BasicAucCalculator::WuaucRocData BasicAucCalculator::computeSingelUserAuc(
    const std::vector<WuaucRecord>& records) {
  double tp = 0.0;
  double fp = 0.0;
  double area = 0.0;

  const size_t total = records.size();
  size_t head = 0;
  while (head < total) {
    // [head, tail) is the run of adjacent records with identical pred_.
    size_t tail = head + 1;
    while (tail < total && records[tail - 1].pred_ == records[tail].pred_) {
      ++tail;
    }

    double step_tp = tp;
    double step_fp = fp;
    for (size_t k = head; k < tail; ++k) {
      if (records[k].label_ == 1) {
        step_tp += 1;
      } else {
        step_fp += 1;
      }
    }

    area += (step_fp - fp) * (tp + step_tp) / 2.0;
    tp = step_tp;
    fp = step_fp;
    head = tail;
  }

  double auc = -1;
  if (tp > 0 && fp > 0) {
    // The 1e-9 term is added to the denominator as written in the formula.
    auc = area / (fp * tp + 1e-9);
  }
  return {tp, fp, auc};
}
} // namespace framework
} // namespace paddle
此差异已折叠。
......@@ -71,6 +71,10 @@ ELSE()
set(STREAM_CALLBACK_DEPS)
ENDIF()
if(WITH_GLOO)
cc_library(gloo_context SRCS gloo_context.cc DEPS framework_proto gloo_wrapper enforce)
endif()
cc_library(cudnn_workspace_helper SRCS cudnn_workspace_helper.cc DEPS boost)
# memcpy depends on device_context, here add deps individually for
......
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/platform/gloo_context.h"
namespace paddle {
namespace platform {
#if defined(PADDLE_WITH_GLOO)
// Pushes this context's strategy into the shared GlooWrapper instance and
// initializes it.
//
// Rank, size, interface, timeouts and the HDFS store settings are
// forwarded. The HTTP-store call is commented out, so ip_address, ip_port
// and scope from the strategy are not forwarded here.
void GlooParallelContext::Init() {
  auto wrapper = paddle::framework::GlooWrapper::GetInstance();
  wrapper->SetRank(strategy_.rank);
  wrapper->SetSize(strategy_.rank_num);
  wrapper->SetIface(strategy_.iface);
  wrapper->SetTimeoutSeconds(strategy_.init_seconds, strategy_.run_seconds);
  // wrapper->SetHttpStore(strategy_.ip_address, strategy_.ip_port,
  //                       strategy_.scope);
  wrapper->SetHdfsStore(strategy_.hdfs_path, strategy_.hdfs_name,
                        strategy_.hdfs_ugi);
  wrapper->Init();
}
// Calls Barrier() on the shared GlooWrapper instance. Raises an Unavailable
// error through PADDLE_ENFORCE_EQ if the wrapper reports that it is not
// initialized.
void GlooParallelContext::Barrier() {
  auto wrapper = paddle::framework::GlooWrapper::GetInstance();
  PADDLE_ENFORCE_EQ(wrapper->IsInitialized(), true,
                    paddle::platform::errors::Unavailable(
                        "Gloo context is not initialized."));
  wrapper->Barrier();
}
// Drops this function's reference to the shared GlooWrapper if the wrapper
// reports that it is initialized.
//
// NOTE(review): GetInstance() returns the shared_ptr by value, so reset()
// below only clears this local copy; the static instance held inside
// GetInstance() keeps the wrapper alive. As written this does not tear
// anything down — confirm whether a real release is intended.
void GlooParallelContext::ReleaseContext() {
  auto wrapper = paddle::framework::GlooWrapper::GetInstance();
  if (wrapper->IsInitialized()) {
    wrapper.reset();
  }
}
#endif
} // namespace platform
} // namespace paddle
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include "paddle/fluid/framework/fleet/gloo_wrapper.h"
namespace paddle {
namespace platform {
#if defined(PADDLE_WITH_GLOO)
// Plain settings bundle consumed by GlooParallelContext.
//
// Every field has a default so that a default-constructed strategy is fully
// initialized and safe to read field by field.
struct GlooParallelStrategy {
  int rank{0};      // rank of this participant
  int rank_num{1};  // total number of participants
  std::string iface;
  int init_seconds{9999999};  // forwarded as the first timeout argument
  int run_seconds{9999999};   // forwarded as the second timeout argument
  std::string hdfs_path;
  std::string hdfs_name;
  std::string hdfs_ugi;
  std::string ip_address;
  // Previously had no initializer, which left it indeterminate in a
  // default-initialized object while still being readable through accessors.
  int ip_port{0};
  std::string scope{"worker"};
};
// Thin object wrapper that holds a copy of a GlooParallelStrategy and
// exposes Init / Barrier / ReleaseContext as overridable operations.
class GlooParallelContext {
 public:
  // Keeps its own copy of `strategy`.
  explicit GlooParallelContext(const GlooParallelStrategy& strategy)
      : strategy_(strategy) {}

  virtual ~GlooParallelContext() = default;

  virtual void Init();
  virtual void Barrier();
  virtual void ReleaseContext();

 protected:
  GlooParallelStrategy strategy_;
};
#endif
} // namespace platform
} // namespace paddle
set(PYBIND_DEPS pybind python proto_desc memory executor fleet_wrapper box_wrapper prune
set(PYBIND_DEPS pybind python proto_desc memory executor fleet_wrapper box_wrapper metrics prune
feed_fetch_method pass_builder parallel_executor profiler layer tracer engine scope_pool
analysis_predictor imperative_profiler imperative_flag save_load_util dlpack_tensor device_context
gloo_wrapper infer_io_utils)
gloo_wrapper gloo_context infer_io_utils)
if (WITH_NCCL)
set(PYBIND_DEPS ${PYBIND_DEPS} nccl_wrapper)
......@@ -32,7 +32,9 @@ set(PYBIND_SRCS
reader_py.cc
fleet_wrapper_py.cc
gloo_wrapper_py.cc
gloo_context_py.cc
box_helper_py.cc
metrics_py.cc
data_set_py.cc
imperative.cc
ir.cc
......
......@@ -253,6 +253,8 @@ void BindDataset(py::module *m) {
py::call_guard<py::gil_scoped_release>())
.def("set_merge_by_sid", &framework::Dataset::SetMergeBySid,
py::call_guard<py::gil_scoped_release>())
.def("set_shuffle_by_uid", &framework::Dataset::SetShuffleByUid,
py::call_guard<py::gil_scoped_release>())
.def("preprocess_instance", &framework::Dataset::PreprocessInstance,
py::call_guard<py::gil_scoped_release>())
.def("postprocess_instance", &framework::Dataset::PostprocessInstance,
......
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/pybind/gloo_context_py.h"
#include <Python.h>
#include <pybind11/chrono.h>
#include <pybind11/complex.h>
#include <pybind11/functional.h>
#include <pybind11/stl.h>
#include <memory>
#include <set>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
#include "paddle/fluid/memory/allocation/mmap_allocator.h"
#include "paddle/fluid/platform/gloo_context.h"
namespace paddle {
namespace pybind {
namespace py = ::pybind11;
// Bind Methods
// Registers the Gloo parallel strategy and context types on module `m`.
//
// GlooParallelStrategy is exposed with read/write attributes rank_num, rank,
// iface, init_seconds, run_seconds, ip_address, ip_port, hdfs_path,
// hdfs_name and hdfs_ugi (the `scope` field is not exposed).
// GlooParallelContext is exposed with init(), barrier() and release().
// Nothing is registered when PADDLE_WITH_GLOO is not defined.
void BindGlooContext(py::module *m) {
// define parallel context for gloo
#if defined(PADDLE_WITH_GLOO)
  using Strategy = platform::GlooParallelStrategy;
  using Context = platform::GlooParallelContext;

  // Each attribute is a plain public field, so def_readwrite replaces the
  // hand-written getter/setter lambda pairs with the same Python-visible
  // behaviour.
  py::class_<Strategy> gloo_parallel_strategy(*m, "GlooParallelStrategy", "");
  gloo_parallel_strategy.def(py::init())
      .def_readwrite("rank_num", &Strategy::rank_num)
      .def_readwrite("rank", &Strategy::rank)
      .def_readwrite("iface", &Strategy::iface)
      .def_readwrite("init_seconds", &Strategy::init_seconds)
      .def_readwrite("run_seconds", &Strategy::run_seconds)
      .def_readwrite("ip_address", &Strategy::ip_address)
      .def_readwrite("ip_port", &Strategy::ip_port)
      .def_readwrite("hdfs_path", &Strategy::hdfs_path)
      .def_readwrite("hdfs_name", &Strategy::hdfs_name)
      .def_readwrite("hdfs_ugi", &Strategy::hdfs_ugi);

  py::class_<Context> gloo_ctx(*m, "GlooParallelContext");
  gloo_ctx.def(py::init<const Strategy &>())
      .def("init", [](Context &self) { self.Init(); })
      .def("barrier", [](Context &self) { self.Barrier(); })
      .def("release", [](Context &self) { self.ReleaseContext(); });
#endif
}
} // namespace pybind
} // namespace paddle
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <Python.h>
#include "pybind11/pybind11.h"
#include "pybind11/stl.h"
namespace paddle {
namespace pybind {
void BindGlooContext(pybind11::module* m);
} // namespace pybind
} // namespace paddle
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <fcntl.h>
#ifdef _POSIX_C_SOURCE
#undef _POSIX_C_SOURCE
#endif
#ifdef _XOPEN_SOURCE
#undef _XOPEN_SOURCE
#endif
#include <memory>
#include <string>
#include <vector>
#include "google/protobuf/text_format.h"
#include "paddle/fluid/framework/fleet/metrics.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/inference/io.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/platform/variant.h"
#include "paddle/fluid/pybind/metrics_py.h"
namespace py = pybind11;
namespace paddle {
namespace pybind {
// Registers the Python class "Metric" on module *m, wrapping
// framework::Metric held through std::shared_ptr.
void BindMetrics(py::module* m) {
  using framework::Metric;
  // Guard attached to every bound member function below: the GIL is released
  // for the duration of the C++ call.
  using ReleaseGil = py::call_guard<py::gil_scoped_release>;

  py::class_<Metric, std::shared_ptr<Metric>> metric_class(*m, "Metric");

  // Constructing Metric() from Python yields whatever Metric::SetInstance()
  // returns rather than building a fresh object here.
  metric_class.def(py::init([]() { return Metric::SetInstance(); }));

  metric_class.def("init_metric", &Metric::InitMetric, ReleaseGil());
  metric_class.def("flip_phase", &Metric::FlipPhase, ReleaseGil());
  metric_class.def("get_metric_msg", &Metric::GetMetricMsg, ReleaseGil());
  metric_class.def("get_wuauc_metric_msg", &Metric::GetWuAucMetricMsg,
                   ReleaseGil());
  metric_class.def("get_metric_name_list", &Metric::GetMetricNameList,
                   ReleaseGil());
}  // end Metrics
} // end namespace pybind
} // end namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "pybind11/pybind11.h"
#include "pybind11/stl.h"
namespace py = pybind11;
namespace paddle {
namespace pybind {
// Registers the Python-visible "Metric" class on the given pybind11 module.
// Declaration only; the definition lives in the matching .cc file.
void BindMetrics(py::module* m);
} // namespace pybind
} // namespace paddle
......@@ -59,11 +59,13 @@ limitations under the License. */
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/platform/profiler.h"
#include "paddle/fluid/pybind/box_helper_py.h"
#include "paddle/fluid/pybind/metrics_py.h"
#include "paddle/fluid/pybind/const_value.h"
#include "paddle/fluid/pybind/data_set_py.h"
#include "paddle/fluid/pybind/exception.h"
#include "paddle/fluid/pybind/fleet_wrapper_py.h"
#include "paddle/fluid/pybind/global_value_getter_setter.h"
#include "paddle/fluid/pybind/gloo_context_py.h"
#include "paddle/fluid/pybind/gloo_wrapper_py.h"
#include "paddle/fluid/pybind/imperative.h"
#include "paddle/fluid/pybind/inference_api.h"
......@@ -2414,11 +2416,15 @@ All parameter, weight, gradient are variables in Paddle.
BindFleetWrapper(&m);
BindGlooWrapper(&m);
BindBoxHelper(&m);
BindMetrics(&m);
#ifdef PADDLE_WITH_BOX_PS
BindBoxWrapper(&m);
#endif
#ifdef PADDLE_WITH_NCCL
BindNCCLWrapper(&m);
#endif
#ifdef PADDLE_WITH_GLOO
BindGlooContext(&m);
#endif
BindGraph(&m);
BindNode(&m);
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .metrics import init_metric # noqa: F401
from .metrics import print_auc # noqa: F401
\ No newline at end of file
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import yaml
import paddle.fluid as fluid
import logging
from paddle.distributed.utils import get_logger
__all__ = []
logger = get_logger(logging.INFO, name="metrics")
# read metric config from yaml and init MetricMsg in fleet_wrapper
def init_metric(metric_ptr,
                metric_yaml_path,
                cmatch_rank_var="",
                mask_var="",
                uid_var="",
                phase=-1,
                cmatch_rank_group="",
                ignore_rank=False,
                bucket_size=1000000):
    """
    Read a YAML metric config and register every monitor on ``metric_ptr``.

    The YAML document must contain a top-level ``monitors`` entry (a list, or
    empty). Each monitor is a mapping with at least the keys ``method``,
    ``name``, ``label``, ``target`` and ``phase``. Depending on ``method``,
    further keys are read from the monitor and override the matching
    function-level default:

    * ``MultiTaskAucCalculator``: ``cmatch_var``, ``cmatch_group``
    * ``CmatchRankAucCalculator``: ``cmatch_var``, ``cmatch_group``,
      ``ignore_rank``
    * ``MaskAucCalculator``: ``mask``
    * ``CmatchRankMaskAucCalculator``: ``cmatch_var``, ``cmatch_group``,
      ``ignore_rank``, ``mask``
    * ``WuAucCalculator``: ``uid``

    Any other method (including ``AucCalculator``) uses the defaults only.

    Args:
        metric_ptr: object exposing ``init_metric(method, name, label, target,
            cmatch_rank_var, mask_var, uid_var, phase, cmatch_rank_group,
            ignore_rank, bucket_size)``.
        metric_yaml_path(str): path of the YAML config file.
        cmatch_rank_var(str): default cmatch/rank variable name.
        mask_var(str): default mask variable name.
        uid_var(str): default uid variable name.
        phase(int): accepted for interface compatibility; the value is not
            used because each monitor's own ``phase`` entry decides the phase
            (``'JOINING'`` -> 1, anything else -> 0).
        cmatch_rank_group(str): default cmatch/rank group.
        ignore_rank(bool): default ignore-rank flag.
        bucket_size(int): bucket size forwarded unchanged for every monitor.

    Raises:
        KeyError: if ``monitors`` or a key required by a monitor is missing.
    """
    # Methods whose monitor entry supplies the corresponding override.
    cmatch_methods = ('MultiTaskAucCalculator', 'CmatchRankAucCalculator',
                      'CmatchRankMaskAucCalculator')
    ignore_rank_methods = ('CmatchRankAucCalculator',
                           'CmatchRankMaskAucCalculator')
    mask_methods = ('MaskAucCalculator', 'CmatchRankMaskAucCalculator')

    # The context manager guarantees the config file is closed even when
    # parsing fails.
    with open(metric_yaml_path) as yaml_fobj:
        if sys.version.startswith('2.7.13'):
            # NOTE(security): yaml.load without an explicit Loader can build
            # arbitrary Python objects; only feed trusted config files here.
            content = yaml.load(yaml_fobj)
        else:
            content = yaml.load(yaml_fobj, Loader=yaml.FullLoader)

    print("yaml metric config: \n")
    print(content)

    metric_runner_list = content['monitors']
    if not metric_runner_list:
        metric_runner_list = []

    for metric_runner in metric_runner_list:
        method = metric_runner['method']
        metric_phase = 1 if metric_runner['phase'] == 'JOINING' else 0

        # Start from the function-level defaults, then apply per-method
        # overrides taken from the monitor entry.
        cur_cmatch_var = cmatch_rank_var
        cur_cmatch_group = cmatch_rank_group
        cur_ignore_rank = ignore_rank
        cur_mask_var = mask_var
        cur_uid_var = uid_var

        if method in cmatch_methods:
            cur_cmatch_var = metric_runner['cmatch_var']
            cur_cmatch_group = metric_runner['cmatch_group']
        if method in ignore_rank_methods:
            cur_ignore_rank = metric_runner['ignore_rank']
        if method in mask_methods:
            cur_mask_var = metric_runner['mask']
        if method == 'WuAucCalculator':
            cur_uid_var = metric_runner['uid']

        # Single call site: every method, including unknown/user-defined
        # ones, passes the full 11-argument list (uid included), so the
        # positional arguments can never shift.
        metric_ptr.init_metric(
            method, metric_runner['name'], metric_runner['label'],
            metric_runner['target'], cur_cmatch_var, cur_mask_var,
            cur_uid_var, metric_phase, cur_cmatch_group, cur_ignore_rank,
            bucket_size)
def print_metric(metric_ptr, name):
    """
    Build the one-line report for the metric called ``name``.

    Names containing "wuauc" are fetched through
    ``metric_ptr.get_wuauc_metric_msg`` and reported with the user-AUC
    layout; every other name goes through ``metric_ptr.get_metric_msg`` and
    uses the regular AUC layout. The formatted string is returned to the
    caller; nothing is logged here.
    """
    if "wuauc" in name:
        values = metric_ptr.get_wuauc_metric_msg(name)
        # Entries 2 and 3 of the WuAUC result are not part of the report.
        user_cnt, ins_cnt = values[0], values[1]
        uauc, wuauc = values[4], values[5]
        return "%s: User Count=%.0f INS Count=%.0f UAUC=%.6f WUAUC=%.6f " % (
            name, user_cnt, ins_cnt, uauc, wuauc)

    values = metric_ptr.get_metric_msg(name)
    layout = ("%s: AUC=%.6f BUCKET_ERROR=%.6f MAE=%.6f RMSE=%.6f "
              "Actual CTR=%.6f Predicted CTR=%.6f COPC=%.6f INS Count=%.0f")
    fields = (name, values[0], values[1], values[2], values[3], values[4],
              values[5], values[6], values[7])
    return layout % fields
def print_auc(metric_ptr, is_day, phase="all"):
    """
    Collect the formatted reports of all metrics matching a stage and phase.

    When ``is_day`` is exactly ``True`` the stage keyword is "day" and the
    name list is requested with stage number -1. Otherwise the keyword is
    "pass" and the stage number is 1 for ``phase == "join"`` and 0 for any
    other phase. A metric is reported when its name contains the stage
    keyword and, unless ``phase`` is "all", also contains ``phase``.

    Returns:
        list of strings produced by ``print_metric``, in name-list order.
    """
    if is_day is True:
        stage, stage_num = "day", -1
    else:
        stage, stage_num = "pass", (1 if phase == "join" else 0)

    candidates = metric_ptr.get_metric_name_list(stage_num)

    def _wanted(metric_name):
        # The stage keyword is mandatory; the phase keyword only when a
        # concrete phase was requested.
        if stage not in metric_name:
            return False
        return phase == "all" or phase in metric_name

    return [
        print_metric(metric_ptr, name=metric_name)
        for metric_name in candidates if _wanted(metric_name)
    ]
......@@ -221,6 +221,23 @@ class DatasetBase(object):
self.dataset.set_filelist(filelist)
self.filelist = filelist
def set_uid_slot(self, uid_slot):
    """
    Record the user-id slot name in the dataset description.

    Args:
        uid_slot(string): user slot name, stored in
            ``proto_desc.multi_slot_desc.uid_slot``.

    Examples:
        .. code-block:: python

            import paddle.fluid as fluid
            dataset = fluid.DatasetFactory().create_dataset()
            dataset.set_uid_slot('6048')
    """
    self.proto_desc.multi_slot_desc.uid_slot = uid_slot
def set_use_var(self, var_list):
"""
Set Variables which you will use.
......@@ -660,6 +677,23 @@ class InMemoryDataset(DatasetBase):
self.merge_by_lineid = True
self.parse_ins_id = True
def set_shuffle_by_uid(self, enable_shuffle_uid):
    """
    Forward the shuffle-by-uid switch to the underlying dataset object.

    Args:
        enable_shuffle_uid(bool): whether to shuffle according to uid.

    Examples:
        .. code-block:: python

            import paddle.fluid as fluid
            dataset = fluid.DatasetFactory().create_dataset("InMemoryDataset")
            dataset.set_shuffle_by_uid(True)
    """
    backend = self.dataset
    backend.set_shuffle_by_uid(enable_shuffle_uid)
def set_generate_unique_feasigns(self, generate_uni_feasigns, shard_num):
self.dataset.set_generate_unique_feasigns(generate_uni_feasigns)
self.gen_uni_feasigns = generate_uni_feasigns
......
......@@ -596,6 +596,7 @@ class GeneralRoleMaker(RoleMakerBase):
self._hdfs_path = kwargs.get("path", "").rstrip("/")
self._init_timeout_seconds = kwargs.get("init_timeout_seconds", 3600)
self._run_timeout_seconds = kwargs.get("run_timeout_seconds", 9999999)
self._use_metric = kwargs.get("use_metric", False)
ip_port = kwargs.get("http_ip_port", "")
self._http_ip_port = []
self._http_server = None
......@@ -609,6 +610,7 @@ class GeneralRoleMaker(RoleMakerBase):
# set running status of http server
self._http_server_d["running"] = False
self._iface = self.__get_default_iface()
self._iface = "" if self._iface == "lo" else self._iface
# this environment variable can be empty
self._prefix = os.getenv("SYS_JOB_ID", "")
......@@ -670,6 +672,21 @@ class GeneralRoleMaker(RoleMakerBase):
self._hdfs_name, self._hdfs_ugi)
gloo.init()
self._node_type_comm = gloo
if self._use_metric:
Gloo_strategy = fluid.core.GlooParallelStrategy()
Gloo_strategy.rank = current_id
Gloo_strategy.rank_num = len(worker_endpoints)
# Gloo_strategy.ip_address = self._http_ip_port[0]
# Gloo_strategy.ip_port = int(self._http_ip_port[1])
Gloo_strategy.hdfs_path = self._hdfs_path + "/trainer"
Gloo_strategy.hdfs_name = self._hdfs_name
Gloo_strategy.hdfs_ugi = self._hdfs_ugi
Default_init_timeout_seconds = 3600
Default_run_timeout_seconds = 9999999
Gloo_strategy.init_seconds = Default_init_timeout_seconds
Gloo_strategy.run_seconds = Default_run_timeout_seconds
Gloo = fluid.core.GlooParallelContext(Gloo_strategy)
Gloo.init()
else:
self._all_comm = MockBarrier()
elif training_role == "PSERVER":
......
......@@ -68,6 +68,14 @@ class TestDataset(unittest.TestCase):
self.assertTrue(dataset.parse_ins_id)
self.assertTrue(dataset.parse_content)
def test_shuffle_by_uid(self):
    """
    Smoke test: setting the uid slot and enabling shuffle-by-uid on an
    InMemoryDataset must complete without raising.
    """
    in_memory_ds = paddle.distributed.InMemoryDataset()
    in_memory_ds._set_uid_slot('6048')
    in_memory_ds._set_shuffle_by_uid(True)
def test_run_with_dump(self):
"""
Testcase for InMemoryDataset from create to run.
......
......@@ -140,6 +140,7 @@ packages=['paddle',
'paddle.dataset',
'paddle.reader',
'paddle.distributed',
'paddle.distributed.metric',
'paddle.complex',
'paddle.complex.tensor',
'paddle.fluid',
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册