diff --git a/paddle/fluid/API.spec b/paddle/fluid/API.spec index 88a8ed06207d9969b042bf9d8cd4d2b3b969042d..64f1b3b436c70f8a107706e0d242096ec70deb50 100644 --- a/paddle/fluid/API.spec +++ b/paddle/fluid/API.spec @@ -518,6 +518,7 @@ paddle.fluid.contrib.BasicLSTMUnit.state_dict (ArgSpec(args=['self', 'destinatio paddle.fluid.contrib.BasicLSTMUnit.sublayers (ArgSpec(args=['self', 'include_sublayers'], varargs=None, keywords=None, defaults=(True,)), ('document', '00a881005ecbc96578faf94513bf0d62')) paddle.fluid.contrib.BasicLSTMUnit.train (ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None), ('document', '6adf97f83acf6453d4a6a4b1070f3754')) paddle.fluid.contrib.basic_lstm (ArgSpec(args=['input', 'init_hidden', 'init_cell', 'hidden_size', 'num_layers', 'sequence_length', 'dropout_prob', 'bidirectional', 'batch_first', 'param_attr', 'bias_attr', 'gate_activation', 'activation', 'forget_bias', 'dtype', 'name'], varargs=None, keywords=None, defaults=(1, None, 0.0, False, True, None, None, None, None, 1.0, 'float32', 'basic_lstm')), ('document', 'fe4d0c3c55a162b8cfe10b05fabb7ce4')) +paddle.fluid.contrib.ctr_metric_bundle (ArgSpec(args=['input', 'label'], varargs=None, keywords=None, defaults=None), ('document', 'b68d12366896c41065fc3738393da2aa')) paddle.fluid.dygraph.Layer ('paddle.fluid.dygraph.layers.Layer', ('document', 'a889d5affd734ede273e94d4257163ab')) paddle.fluid.dygraph.Layer.__init__ (ArgSpec(args=['self', 'name_scope', 'dtype'], varargs=None, keywords=None, defaults=(VarType.FP32,)), ('document', '6adf97f83acf6453d4a6a4b1070f3754')) paddle.fluid.dygraph.Layer.add_parameter (ArgSpec(args=['self', 'name', 'parameter'], varargs=None, keywords=None, defaults=None), ('document', 'f35ab374c7d5165c3daf3bd64a5a2ec1')) diff --git a/paddle/fluid/framework/data_feed.cc b/paddle/fluid/framework/data_feed.cc index 627b7790e87a3c2ba611460f7c8b72c10f94cfaf..f33b4668f01c139c6c8ba7f93c05f7fa3322633f 100644 --- a/paddle/fluid/framework/data_feed.cc +++ b/paddle/fluid/framework/data_feed.cc @@ -33,11 +33,53 @@ limitations under the License. 
*/ #include "io/shell.h" #include "paddle/fluid/framework/feed_fetch_method.h" #include "paddle/fluid/framework/feed_fetch_type.h" +#include "paddle/fluid/framework/fleet/fleet_wrapper.h" #include "paddle/fluid/platform/timer.h" namespace paddle { namespace framework { +void RecordCandidateList::ReSize(size_t length) { + _mutex.lock(); + _capacity = length; + CHECK(_capacity > 0); // NOLINT + _candidate_list.clear(); + _candidate_list.resize(_capacity); + _full = false; + _cur_size = 0; + _total_size = 0; + _mutex.unlock(); +} + +void RecordCandidateList::ReInit() { + _mutex.lock(); + _full = false; + _cur_size = 0; + _total_size = 0; + _mutex.unlock(); +} + +void RecordCandidateList::AddAndGet(const Record& record, + RecordCandidate* result) { + _mutex.lock(); + size_t index = 0; + ++_total_size; + auto fleet_ptr = FleetWrapper::GetInstance(); + if (!_full) { + _candidate_list[_cur_size++] = record; + _full = (_cur_size == _capacity); + } else { + CHECK(_cur_size == _capacity); + index = fleet_ptr->LocalRandomEngine()() % _total_size; + if (index < _capacity) { + _candidate_list[index] = record; + } + } + index = fleet_ptr->LocalRandomEngine()() % _cur_size; + *result = _candidate_list[index]; + _mutex.unlock(); +} + void DataFeed::AddFeedVar(Variable* var, const std::string& name) { CheckInit(); for (size_t i = 0; i < use_slots_.size(); ++i) { diff --git a/paddle/fluid/framework/data_feed.h b/paddle/fluid/framework/data_feed.h index 400212f1d278705782c1cab680ebb5843601351c..5b314905143350b9d547fe08116703ff92dd4203 100644 --- a/paddle/fluid/framework/data_feed.h +++ b/paddle/fluid/framework/data_feed.h @@ -26,6 +26,7 @@ limitations under the License. */ #include #include #include // NOLINT +#include #include #include @@ -427,6 +428,41 @@ struct Record { std::string ins_id_; }; +struct RecordCandidate { + std::string ins_id_; + std::unordered_multimap feas; + + RecordCandidate& operator=(const Record& rec) { + feas.clear(); + ins_id_ = rec.ins_id_; + for (auto& fea : rec.uint64_feasigns_) { + feas.insert({fea.slot(), fea.sign()}); + } + return *this; + } +}; + +class RecordCandidateList { + public: + RecordCandidateList() = default; + RecordCandidateList(const RecordCandidateList&) = delete; + RecordCandidateList& operator=(const RecordCandidateList&) = delete; + + void ReSize(size_t length); + + void ReInit(); + + void AddAndGet(const Record& record, RecordCandidate* result); + + private: + size_t _capacity = 0; + std::mutex _mutex; + bool _full = false; + size_t _cur_size = 0; + size_t _total_size = 0; + std::vector _candidate_list; +}; + template paddle::framework::Archive& operator<<(paddle::framework::Archive& ar, const FeatureKey& fk) { diff --git a/paddle/fluid/framework/data_set.cc b/paddle/fluid/framework/data_set.cc index 114496085429b18607ff84178e181c36bd2d1adb..a7e12cb817b0848e2b72e9457918af920df01964 100644 --- a/paddle/fluid/framework/data_set.cc +++ b/paddle/fluid/framework/data_set.cc @@ -114,6 +114,14 @@ void DatasetImpl::SetMergeByInsId( keep_unmerged_ins_ = keep_unmerged_ins; } +template +void DatasetImpl::SetFeaEval(bool fea_eval, int record_candidate_size) { + slots_shuffle_fea_eval_ = fea_eval; + slots_shuffle_rclist_.ReSize(record_candidate_size); + VLOG(3) << "SetFeaEval fea eval mode: " << fea_eval + << " with record candidate size: " << record_candidate_size; +} + template std::vector DatasetImpl::GetReaders() { std::vector ret; @@ -646,5 +654,167 @@ void MultiSlotDataset::MergeByInsId() { VLOG(3) << "MultiSlotDataset::MergeByInsId end"; } +void 
MultiSlotDataset::GetRandomData(const std::set& slots_to_replace, + std::vector* result) { + int debug_erase_cnt = 0; + int debug_push_cnt = 0; + auto multi_slot_desc = data_feed_desc_.multi_slot_desc(); + slots_shuffle_rclist_.ReInit(); + for (const auto& rec : slots_shuffle_original_data_) { + RecordCandidate rand_rec; + Record new_rec = rec; + slots_shuffle_rclist_.AddAndGet(rec, &rand_rec); + for (auto it = new_rec.uint64_feasigns_.begin(); + it != new_rec.uint64_feasigns_.end();) { + if (slots_to_replace.find(it->slot()) != slots_to_replace.end()) { + it = new_rec.uint64_feasigns_.erase(it); + debug_erase_cnt += 1; + } else { + ++it; + } + } + for (auto slot : slots_to_replace) { + auto range = rand_rec.feas.equal_range(slot); + for (auto it = range.first; it != range.second; ++it) { + new_rec.uint64_feasigns_.push_back({it->second, it->first}); + debug_push_cnt += 1; + } + } + result->push_back(std::move(new_rec)); + } + VLOG(2) << "erase feasign num: " << debug_erase_cnt + << " repush feasign num: " << debug_push_cnt; +} + +// slots shuffle to input_channel_ with needed-shuffle slots +void MultiSlotDataset::SlotsShuffle( + const std::set& slots_to_replace) { + int out_channel_size = 0; + if (cur_channel_ == 0) { + for (size_t i = 0; i < multi_output_channel_.size(); ++i) { + out_channel_size += multi_output_channel_[i]->Size(); + } + } else { + for (size_t i = 0; i < multi_consume_channel_.size(); ++i) { + out_channel_size += multi_consume_channel_[i]->Size(); + } + } + VLOG(2) << "DatasetImpl::SlotsShuffle() begin with input channel size: " + << input_channel_->Size() + << " output channel size: " << out_channel_size; + if (!slots_shuffle_fea_eval_) { + VLOG(3) << "DatasetImpl::SlotsShuffle() end," + "fea eval mode off, need to set on for slots shuffle"; + return; + } + if ((!input_channel_ || input_channel_->Size() == 0) && + slots_shuffle_original_data_.size() == 0 && out_channel_size == 0) { + VLOG(3) << "DatasetImpl::SlotsShuffle() end, no data to slots shuffle"; + return; + } + platform::Timer timeline; + timeline.Start(); + auto multi_slot_desc = data_feed_desc_.multi_slot_desc(); + std::set index_slots; + for (size_t i = 0; i < multi_slot_desc.slots_size(); ++i) { + std::string cur_slot = multi_slot_desc.slots(i).name(); + if (slots_to_replace.find(cur_slot) != slots_to_replace.end()) { + index_slots.insert(i); + } + } + if (slots_shuffle_original_data_.size() == 0) { + // before first slots shuffle, instances could be in + // input_channel, oupput_channel or consume_channel + if (input_channel_ && input_channel_->Size() != 0) { + slots_shuffle_original_data_.reserve(input_channel_->Size()); + input_channel_->Close(); + input_channel_->ReadAll(slots_shuffle_original_data_); + } else { + CHECK(out_channel_size > 0); // NOLINT + if (cur_channel_ == 0) { + for (size_t i = 0; i < multi_output_channel_.size(); ++i) { + std::vector vec_data; + multi_output_channel_[i]->Close(); + multi_output_channel_[i]->ReadAll(vec_data); + slots_shuffle_original_data_.reserve( + slots_shuffle_original_data_.size() + vec_data.size()); + slots_shuffle_original_data_.insert( + slots_shuffle_original_data_.end(), + std::make_move_iterator(vec_data.begin()), + std::make_move_iterator(vec_data.end())); + vec_data.clear(); + vec_data.shrink_to_fit(); + multi_output_channel_[i]->Clear(); + } + } else { + for (size_t i = 0; i < multi_consume_channel_.size(); ++i) { + std::vector vec_data; + multi_consume_channel_[i]->Close(); + multi_consume_channel_[i]->ReadAll(vec_data); + 
slots_shuffle_original_data_.reserve( + slots_shuffle_original_data_.size() + vec_data.size()); + slots_shuffle_original_data_.insert( + slots_shuffle_original_data_.end(), + std::make_move_iterator(vec_data.begin()), + std::make_move_iterator(vec_data.end())); + vec_data.clear(); + vec_data.shrink_to_fit(); + multi_consume_channel_[i]->Clear(); + } + } + } + } else { + // if already have original data for slots shuffle, clear channel + input_channel_->Clear(); + if (cur_channel_ == 0) { + for (size_t i = 0; i < multi_output_channel_.size(); ++i) { + if (!multi_output_channel_[i]) { + continue; + } + multi_output_channel_[i]->Clear(); + } + } else { + for (size_t i = 0; i < multi_consume_channel_.size(); ++i) { + if (!multi_consume_channel_[i]) { + continue; + } + multi_consume_channel_[i]->Clear(); + } + } + } + int end_size = 0; + if (cur_channel_ == 0) { + for (size_t i = 0; i < multi_output_channel_.size(); ++i) { + if (!multi_output_channel_[i]) { + continue; + } + end_size += multi_output_channel_[i]->Size(); + } + } else { + for (size_t i = 0; i < multi_consume_channel_.size(); ++i) { + if (!multi_consume_channel_[i]) { + continue; + } + end_size += multi_consume_channel_[i]->Size(); + } + } + CHECK(input_channel_->Size() == 0) + << "input channel should be empty before slots shuffle"; + std::vector random_data; + random_data.clear(); + // get slots shuffled random_data + GetRandomData(index_slots, &random_data); + input_channel_->Open(); + input_channel_->Write(std::move(random_data)); + random_data.clear(); + random_data.shrink_to_fit(); + input_channel_->Close(); + + timeline.Pause(); + VLOG(2) << "DatasetImpl::SlotsShuffle() end" + << ", memory data size for slots shuffle=" << input_channel_->Size() + << ", cost time=" << timeline.ElapsedSec() << " seconds"; +} + } // end namespace framework } // end namespace paddle diff --git a/paddle/fluid/framework/data_set.h b/paddle/fluid/framework/data_set.h index 3c40a7c0cecc0b1bbb51aebcb900da2f52602e0f..7b725a6f2739c439e8edbeba498c1ea77c840af0 100644 --- a/paddle/fluid/framework/data_set.h +++ b/paddle/fluid/framework/data_set.h @@ -17,6 +17,7 @@ #include #include #include // NOLINT +#include #include #include // NOLINT #include @@ -61,6 +62,8 @@ class Dataset { virtual void SetMergeByInsId(const std::vector& merge_slot_list, bool erase_duplicate_feas, int min_merge_size, bool keep_unmerged_ins) = 0; + // set fea eval mode + virtual void SetFeaEval(bool fea_eval, int record_candidate_size) = 0; // get file list virtual const std::vector& GetFileList() = 0; // get thread num @@ -94,6 +97,10 @@ class Dataset { virtual void LocalShuffle() = 0; // global shuffle data virtual void GlobalShuffle() = 0; + // for slots shuffle + virtual void SlotsShuffle(const std::set& slots_to_replace) = 0; + virtual void GetRandomData(const std::set& slots_to_replace, + std::vector* result) = 0; // create readers virtual void CreateReaders() = 0; // destroy readers @@ -130,6 +137,7 @@ class DatasetImpl : public Dataset { bool erase_duplicate_feas, int min_merge_size, bool keep_unmerged_ins); + virtual void SetFeaEval(bool fea_eval, int record_candidate_size); virtual const std::vector& GetFileList() { return filelist_; } virtual int GetThreadNum() { return thread_num_; } virtual int GetTrainerNum() { return trainer_num_; } @@ -150,6 +158,9 @@ class DatasetImpl : public Dataset { virtual void ReleaseMemory(); virtual void LocalShuffle(); virtual void GlobalShuffle(); + virtual void SlotsShuffle(const std::set& slots_to_replace) {} + virtual void 
GetRandomData(const std::set& slots_to_replace, + std::vector* result) {} virtual void CreateReaders(); virtual void DestroyReaders(); virtual int64_t GetMemoryDataSize(); @@ -168,6 +179,8 @@ class DatasetImpl : public Dataset { // and when finish reading, we set cur_channel = 1 - cur_channel, // so if cur_channel=0, all data are in output_channel, else consume_channel int cur_channel_; + std::vector slots_shuffle_original_data_; + RecordCandidateList slots_shuffle_rclist_; int thread_num_; paddle::framework::DataFeedDesc data_feed_desc_; int trainer_num_; @@ -184,6 +197,7 @@ class DatasetImpl : public Dataset { bool keep_unmerged_ins_; int min_merge_size_; std::vector merge_slots_list_; + bool slots_shuffle_fea_eval_ = false; }; // use std::vector or Record as data type @@ -191,6 +205,9 @@ class MultiSlotDataset : public DatasetImpl { public: MultiSlotDataset() {} virtual void MergeByInsId(); + virtual void SlotsShuffle(const std::set& slots_to_replace); + virtual void GetRandomData(const std::set& slots_to_replace, + std::vector* result); virtual ~MultiSlotDataset() {} }; diff --git a/paddle/fluid/framework/fleet/fleet_wrapper.cc b/paddle/fluid/framework/fleet/fleet_wrapper.cc index 3f4f345912467881ba0e83650c9ba1ee9aeee7b7..8d28e1cabfadc2912846b824c1945bcd39763ad1 100644 --- a/paddle/fluid/framework/fleet/fleet_wrapper.cc +++ b/paddle/fluid/framework/fleet/fleet_wrapper.cc @@ -512,6 +512,57 @@ void FleetWrapper::SaveModel(const std::string& path, const int mode) { #endif } +double FleetWrapper::GetCacheThreshold() { +#ifdef PADDLE_WITH_PSLIB + double cache_threshold = 0.0; + auto ret = pslib_ptr_->_worker_ptr->flush(); + ret.wait(); + ret = pslib_ptr_->_worker_ptr->get_cache_threshold(0, cache_threshold); + ret.wait(); + if (cache_threshold < 0) { + LOG(ERROR) << "get cache threshold failed"; + exit(-1); + } + return cache_threshold; +#else + VLOG(0) << "FleetWrapper::GetCacheThreshold does nothing when no pslib"; + return 0.0; +#endif +} + +void FleetWrapper::CacheShuffle(int table_id, const std::string& path, + const int mode, const double cache_threshold) { +#ifdef PADDLE_WITH_PSLIB + auto ret = pslib_ptr_->_worker_ptr->cache_shuffle( + 0, path, std::to_string(mode), std::to_string(cache_threshold)); + ret.wait(); + int32_t feasign_cnt = ret.get(); + if (feasign_cnt == -1) { + LOG(ERROR) << "cache shuffle failed"; + exit(-1); + } +#else + VLOG(0) << "FleetWrapper::CacheShuffle does nothing when no pslib"; +#endif +} + +int32_t FleetWrapper::SaveCache(int table_id, const std::string& path, + const int mode) { +#ifdef PADDLE_WITH_PSLIB + auto ret = pslib_ptr_->_worker_ptr->save_cache(0, path, std::to_string(mode)); + ret.wait(); + int32_t feasign_cnt = ret.get(); + if (feasign_cnt == -1) { + LOG(ERROR) << "table save cache failed"; + exit(-1); + } + return feasign_cnt; +#else + VLOG(0) << "FleetWrapper::SaveCache does nothing when no pslib"; + return -1; +#endif +} + void FleetWrapper::ShrinkSparseTable(int table_id) { #ifdef PADDLE_WITH_PSLIB auto ret = pslib_ptr_->_worker_ptr->shrink(table_id); diff --git a/paddle/fluid/framework/fleet/fleet_wrapper.h b/paddle/fluid/framework/fleet/fleet_wrapper.h index 17b58e575950edc61fd1ae6ba982f47ce15b03f6..e0456906d3c2494a5df5accef9aebe2ff9eed156 100644 --- a/paddle/fluid/framework/fleet/fleet_wrapper.h +++ b/paddle/fluid/framework/fleet/fleet_wrapper.h @@ -148,7 +148,13 @@ class FleetWrapper { // mode = 1, save delta feature, which means save diff void SaveModel(const std::string& path, const int mode); + double GetCacheThreshold(); + void 
CacheShuffle(int table_id, const std::string& path, const int mode, + const double cache_threshold); + int32_t SaveCache(int table_id, const std::string& path, const int mode); + void ClearModel(); + void ShrinkSparseTable(int table_id); void ShrinkDenseTable(int table_id, Scope* scope, std::vector var_list, float decay, diff --git a/paddle/fluid/pybind/data_set_py.cc b/paddle/fluid/pybind/data_set_py.cc index 0e88027ea906dd560422531e77604aa7f5e3abb6..3e787822ecd642b94c80e016415c0286ea4d5926 100644 --- a/paddle/fluid/pybind/data_set_py.cc +++ b/paddle/fluid/pybind/data_set_py.cc @@ -103,6 +103,10 @@ void BindDataset(py::module* m) { .def("set_merge_by_lineid", &framework::Dataset::SetMergeByInsId, py::call_guard()) .def("merge_by_lineid", &framework::Dataset::MergeByInsId, + py::call_guard()) + .def("slots_shuffle", &framework::Dataset::SlotsShuffle, + py::call_guard()) + .def("set_fea_eval", &framework::Dataset::SetFeaEval, py::call_guard()); } diff --git a/paddle/fluid/pybind/fleet_wrapper_py.cc b/paddle/fluid/pybind/fleet_wrapper_py.cc index 92cda13481fccc02feed4261d4ab2584dd0e6682..90772b3546c9dfdcc94fc85ca2b804365f03c021 100644 --- a/paddle/fluid/pybind/fleet_wrapper_py.cc +++ b/paddle/fluid/pybind/fleet_wrapper_py.cc @@ -49,6 +49,9 @@ void BindFleetWrapper(py::module* m) { .def("init_worker", &framework::FleetWrapper::InitWorker) .def("init_model", &framework::FleetWrapper::PushDenseParamSync) .def("save_model", &framework::FleetWrapper::SaveModel) + .def("get_cache_threshold", &framework::FleetWrapper::GetCacheThreshold) + .def("cache_shuffle", &framework::FleetWrapper::CacheShuffle) + .def("save_cache", &framework::FleetWrapper::SaveCache) .def("load_model", &framework::FleetWrapper::LoadModel) .def("clear_model", &framework::FleetWrapper::ClearModel) .def("stop_server", &framework::FleetWrapper::StopServer) diff --git a/python/paddle/fluid/contrib/layers/__init__.py b/python/paddle/fluid/contrib/layers/__init__.py index 6ba971b527cf7a5dddd450652b246847cc8437a5..94889a65b3620f730dcd39c911599f50acbfe614 100644 --- a/python/paddle/fluid/contrib/layers/__init__.py +++ b/python/paddle/fluid/contrib/layers/__init__.py @@ -16,8 +16,12 @@ from __future__ import print_function from . import nn from .nn import * + from .rnn_impl import * +from . import metric_op +from .metric_op import * __all__ = [] __all__ += nn.__all__ __all__ += rnn_impl.__all__ +__all__ += metric_op.__all__ diff --git a/python/paddle/fluid/contrib/layers/metric_op.py b/python/paddle/fluid/contrib/layers/metric_op.py new file mode 100644 index 0000000000000000000000000000000000000000..f76a3283f2f81880fce5cd8b8fa4fc46434fd165 --- /dev/null +++ b/python/paddle/fluid/contrib/layers/metric_op.py @@ -0,0 +1,188 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Contrib layers just related to metric. 
+""" + + from __future__ import print_function + + import warnings + from paddle.fluid.layer_helper import LayerHelper + from paddle.fluid.initializer import Normal, Constant + from paddle.fluid.framework import Variable + from paddle.fluid.param_attr import ParamAttr + from paddle.fluid.layers import nn + + __all__ = ['ctr_metric_bundle'] + + + def ctr_metric_bundle(input, label): + """ + CTR related metric layer + + This function helps compute the CTR related metrics: RMSE, MAE, predicted_ctr, q_value. + To compute the final values of these metrics, the following computations should be done with the + total instance number: + MAE = local_abserr / instance number + RMSE = sqrt(local_sqrerr / instance number) + predicted_ctr = local_prob / instance number + q = local_q / instance number + Note that if you are running a distributed job, you should all-reduce these metrics and the instance + number first. + + Args: + input(Variable): A floating-point 2D Variable, values are in the range + [0, 1]. Each row is sorted in descending order. This + input should be the output of topk. Typically, this + Variable indicates the probability of each label. + label(Variable): A 2D int Variable indicating the label of the training + data. The height is batch size and width is always 1. + + Returns: + local_sqrerr(Variable): Local sum of squared error + local_abserr(Variable): Local sum of abs error + local_prob(Variable): Local sum of predicted ctr + local_q(Variable): Local sum of q value + local_pos_num(Variable): Local sum of positive instance count + local_ins_num(Variable): Local instance count + + Examples: + .. code-block:: python + + import paddle.fluid as fluid + data = fluid.layers.data(name="data", shape=[32, 32], dtype="float32") + label = fluid.layers.data(name="label", shape=[1], dtype="int32") + predict = fluid.layers.sigmoid(fluid.layers.fc(input=data, size=1)) + auc_out = fluid.contrib.layers.ctr_metric_bundle(input=predict, label=label) + """ + assert input.shape == label.shape + helper = LayerHelper("ctr_metric_bundle", **locals()) + + local_abserr = helper.create_global_variable( + persistable=True, dtype='float32', shape=[1]) + local_sqrerr = helper.create_global_variable( + persistable=True, dtype='float32', shape=[1]) + local_prob = helper.create_global_variable( + persistable=True, dtype='float32', shape=[1]) + local_q = helper.create_global_variable( + persistable=True, dtype='float32', shape=[1]) + local_pos_num = helper.create_global_variable( + persistable=True, dtype='float32', shape=[1]) + local_ins_num = helper.create_global_variable( + persistable=True, dtype='float32', shape=[1]) + + tmp_res_elesub = helper.create_global_variable( + persistable=False, dtype='float32', shape=[-1]) + tmp_res_sigmoid = helper.create_global_variable( + persistable=False, dtype='float32', shape=[-1]) + tmp_ones = helper.create_global_variable( + persistable=False, dtype='float32', shape=[-1]) + + batch_prob = helper.create_global_variable( + persistable=False, dtype='float32', shape=[1]) + batch_abserr = helper.create_global_variable( + persistable=False, dtype='float32', shape=[1]) + batch_sqrerr = helper.create_global_variable( + persistable=False, dtype='float32', shape=[1]) + batch_q = helper.create_global_variable( + persistable=False, dtype='float32', shape=[1]) + batch_pos_num = helper.create_global_variable( + persistable=False, dtype='float32', shape=[1]) + batch_ins_num = helper.create_global_variable( + persistable=False, dtype='float32', shape=[1]) + for var in [ + local_abserr, batch_abserr, local_sqrerr, batch_sqrerr, local_prob, + batch_prob, local_q, batch_q, batch_pos_num, batch_ins_num, + local_pos_num,
local_ins_num + ]: + helper.set_variable_initializer( + var, Constant( + value=0.0, force_cpu=True)) + + helper.append_op( + type="elementwise_sub", + inputs={"X": [input], + "Y": [label]}, + outputs={"Out": [tmp_res_elesub]}) + + helper.append_op( + type="squared_l2_norm", + inputs={"X": [tmp_res_elesub]}, + outputs={"Out": [batch_sqrerr]}) + helper.append_op( + type="elementwise_add", + inputs={"X": [batch_sqrerr], + "Y": [local_sqrerr]}, + outputs={"Out": [local_sqrerr]}) + + helper.append_op( + type="l1_norm", + inputs={"X": [tmp_res_elesub]}, + outputs={"Out": [batch_abserr]}) + helper.append_op( + type="elementwise_add", + inputs={"X": [batch_abserr], + "Y": [local_abserr]}, + outputs={"Out": [local_abserr]}) + + helper.append_op( + type="reduce_sum", inputs={"X": [input]}, + outputs={"Out": [batch_prob]}) + helper.append_op( + type="elementwise_add", + inputs={"X": [batch_prob], + "Y": [local_prob]}, + outputs={"Out": [local_prob]}) + helper.append_op( + type="sigmoid", + inputs={"X": [input]}, + outputs={"Out": [tmp_res_sigmoid]}) + helper.append_op( + type="reduce_sum", + inputs={"X": [tmp_res_sigmoid]}, + outputs={"Out": [batch_q]}) + helper.append_op( + type="elementwise_add", + inputs={"X": [batch_q], + "Y": [local_q]}, + outputs={"Out": [local_q]}) + + helper.append_op( + type="reduce_sum", + inputs={"X": [label]}, + outputs={"Out": [batch_pos_num]}) + helper.append_op( + type="elementwise_add", + inputs={"X": [batch_pos_num], + "Y": [local_pos_num]}, + outputs={"Out": [local_pos_num]}) + + helper.append_op( + type='fill_constant_batch_size_like', + inputs={"Input": label}, + outputs={'Out': [tmp_ones]}, + attrs={ + 'shape': [-1, 1], + 'dtype': tmp_ones.dtype, + 'value': float(1.0), + }) + helper.append_op( + type="reduce_sum", + inputs={"X": [tmp_ones]}, + outputs={"Out": [batch_ins_num]}) + helper.append_op( + type="elementwise_add", + inputs={"X": [batch_ins_num], + "Y": [local_ins_num]}, + outputs={"Out": [local_ins_num]}) + + return local_sqrerr, local_abserr, local_prob, local_q, local_pos_num, local_ins_num diff --git a/python/paddle/fluid/dataset.py b/python/paddle/fluid/dataset.py index 20ffd13d605779a5298efd10d947cf55868905f3..d4c8a32d6cf54768521d57e9416b1694054c3f25 100644 --- a/python/paddle/fluid/dataset.py +++ b/python/paddle/fluid/dataset.py @@ -91,6 +91,51 @@ class DatasetBase(object): """ self.proto_desc.pipe_command = pipe_command + + def set_fea_eval(self, record_candidate_size, fea_eval=True): + """ + Set fea eval mode for slots shuffle, which is used to debug the importance level of + slots (features); fea_eval needs to be set to True for slots shuffle. + + Args: + record_candidate_size(int): size of the candidate pool of instances used to shuffle + one slot + fea_eval(bool): whether to enable fea eval mode, which is required for slots shuffle. + Default is True. + + Examples: + .. code-block:: python + + import paddle.fluid as fluid + dataset = fluid.DatasetFactory().create_dataset("InMemoryDataset") + dataset.set_fea_eval(1000000, True) + + """ + if fea_eval: + self.dataset.set_fea_eval(fea_eval, record_candidate_size) + self.fea_eval = fea_eval + + def slots_shuffle(self, slots): + """ + Slots Shuffle + Slots Shuffle is a shuffle method at the slot level, which is usually used + for sparse features with a large scale of instances. Compare the metric, e.g. + AUC, when doing slots shuffle on one or several slots against the baseline to + evaluate the importance level of slots (features). + + Args: + slots(list[string]): the set of slots (strings) to do slots shuffle.
+ + Examples: + import paddle.fluid as fluid + dataset = fluid.DatasetFactory().create_dataset("InMemoryDataset") + dataset.set_merge_by_lineid() + #suppose there is a slot 0 + dataset.slots_shuffle(['0']) + """ + if self.fea_eval: + slots_set = set(slots) + self.dataset.slots_shuffle(slots_set) + def set_batch_size(self, batch_size): """ Set batch size. Will be effective during training diff --git a/python/paddle/fluid/incubate/fleet/parameter_server/pslib/__init__.py b/python/paddle/fluid/incubate/fleet/parameter_server/pslib/__init__.py index ac56142245b6ab3b4d94546c0abce7bc9f6f0971..b4993dae9dba6c439d305b3e223caa9391d4af07 100644 --- a/python/paddle/fluid/incubate/fleet/parameter_server/pslib/__init__.py +++ b/python/paddle/fluid/incubate/fleet/parameter_server/pslib/__init__.py @@ -15,7 +15,6 @@ import os import sys from optimizer_factory import * from google.protobuf import text_format - import paddle.fluid as fluid from paddle.fluid.framework import Program @@ -212,6 +211,45 @@ class PSLib(Fleet): self._fleet_ptr.save_model(dirname, mode) self._role_maker._barrier_worker() + def save_cache_model(self, executor, dirname, main_program=None, **kwargs): + """ + save sparse cache table, + when using fleet, it will save sparse cache table + + Args: + dirname(str): save path. It can be hdfs/afs path or local path + main_program(Program): fluid program, default None + kwargs: use define property, current support following + mode(int): define for feature extension in the future, + currently no use, will pass a default value 0 + + Example: + .. code-block:: python + >>> fleet.save_cache_model(None, dirname="/you/path/to/model", mode = 0) + + """ + mode = kwargs.get("mode", 0) + self._fleet_ptr.client_flush() + self._role_maker._barrier_worker() + cache_threshold = 0.0 + + if self._role_maker.is_first_worker(): + cache_threshold = self._fleet_ptr.get_cache_threshold() + #check cache threshold right or not + self._role_maker._barrier_worker() + + if self._role_maker.is_first_worker(): + self._fleet_ptr.cache_shuffle(0, dirname, mode, cache_threshold) + + self._role_maker._barrier_worker() + + feasign_num = -1 + if self._role_maker.is_first_worker(): + feasign_num = self._fleet_ptr.save_cache(0, dirname, mode) + + self._role_maker._barrier_worker() + return feasign_num + def shrink_sparse_table(self): """ shrink cvm of all sparse embedding in pserver, the decay rate diff --git a/python/paddle/fluid/incubate/fleet/parameter_server/pslib/ps_pb2.py b/python/paddle/fluid/incubate/fleet/parameter_server/pslib/ps_pb2.py index 795fe79ca866890986ddcfd0816eb1a4c909fecd..888a5fa7a79de234000e25eb690e0bab8302a302 100644 --- a/python/paddle/fluid/incubate/fleet/parameter_server/pslib/ps_pb2.py +++ b/python/paddle/fluid/incubate/fleet/parameter_server/pslib/ps_pb2.py @@ -32,7 +32,7 @@ DESCRIPTOR = _descriptor.FileDescriptor( package='paddle', syntax='proto2', serialized_pb=_b( - '\n\x08ps.proto\x12\x06paddle\"\x9e\x02\n\x0bPSParameter\x12\x14\n\x0cworker_class\x18\x01 \x01(\t\x12\x14\n\x0cserver_class\x18\x02 \x01(\t\x12\x16\n\x0einstance_class\x18\x03 \x01(\t\x12-\n\x0cworker_param\x18\x65 \x01(\x0b\x32\x17.paddle.WorkerParameter\x12-\n\x0cserver_param\x18\x66 \x01(\x0b\x32\x17.paddle.ServerParameter\x12\x38\n\rtrainer_param\x18\xad\x02 \x01(\x0b\x32 .paddle.DownpourTrainerParameter\x12\x33\n\x0f\x66s_client_param\x18\xf5\x03 \x01(\x0b\x32\x19.paddle.FsClientParameter\"Q\n\x0fWorkerParameter\x12>\n\x15\x64ownpour_worker_param\x18\x01 
\x01(\x0b\x32\x1f.paddle.DownpourWorkerParameter\"Q\n\x0fServerParameter\x12>\n\x15\x64ownpour_server_param\x18\x01 \x01(\x0b\x32\x1f.paddle.DownpourServerParameter\"O\n\x17\x44ownpourWorkerParameter\x12\x34\n\x14\x64ownpour_table_param\x18\x01 \x03(\x0b\x32\x16.paddle.TableParameter\"\xfd\x01\n\x18\x44ownpourTrainerParameter\x12\x30\n\x0b\x64\x65nse_table\x18\x01 \x03(\x0b\x32\x1b.paddle.DenseTableParameter\x12\x32\n\x0csparse_table\x18\x02 \x03(\x0b\x32\x1c.paddle.SparseTableParameter\x12\x1d\n\x15push_sparse_per_batch\x18\x03 \x01(\x05\x12\x1c\n\x14push_dense_per_batch\x18\x04 \x01(\x05\x12\x0f\n\x07skip_op\x18\x05 \x03(\t\x12-\n\x0eprogram_config\x18\x06 \x03(\x0b\x32\x15.paddle.ProgramConfig\"\x99\x01\n\rProgramConfig\x12\x12\n\nprogram_id\x18\x01 \x02(\t\x12\x1c\n\x14push_sparse_table_id\x18\x02 \x03(\x05\x12\x1b\n\x13push_dense_table_id\x18\x03 \x03(\x05\x12\x1c\n\x14pull_sparse_table_id\x18\x04 \x03(\x05\x12\x1b\n\x13pull_dense_table_id\x18\x05 \x03(\x05\"{\n\x13\x44\x65nseTableParameter\x12\x10\n\x08table_id\x18\x01 \x01(\x05\x12\x1b\n\x13\x64\x65nse_variable_name\x18\x02 \x03(\t\x12$\n\x1c\x64\x65nse_gradient_variable_name\x18\x03 \x03(\t\x12\x0f\n\x07\x66\x65\x61_dim\x18\x04 \x01(\x05\"z\n\x14SparseTableParameter\x12\x10\n\x08table_id\x18\x01 \x01(\x05\x12\x13\n\x0b\x66\x65\x61ture_dim\x18\x02 \x01(\x05\x12\x10\n\x08slot_key\x18\x03 \x03(\t\x12\x12\n\nslot_value\x18\x04 \x03(\t\x12\x15\n\rslot_gradient\x18\x05 \x03(\t\"\x86\x01\n\x17\x44ownpourServerParameter\x12\x34\n\x14\x64ownpour_table_param\x18\x01 \x03(\x0b\x32\x16.paddle.TableParameter\x12\x35\n\rservice_param\x18\x02 \x01(\x0b\x32\x1e.paddle.ServerServiceParameter\"\xd7\x01\n\x16ServerServiceParameter\x12*\n\x0cserver_class\x18\x01 \x01(\t:\x14\x44ownpourBrpcPsServer\x12*\n\x0c\x63lient_class\x18\x02 \x01(\t:\x14\x44ownpourBrpcPsClient\x12(\n\rservice_class\x18\x03 \x01(\t:\x11\x44ownpourPsService\x12\x1c\n\x11start_server_port\x18\x04 \x01(\r:\x01\x30\x12\x1d\n\x11server_thread_num\x18\x05 \x01(\r:\x02\x31\x32\"\xc4\x01\n\x0eTableParameter\x12\x10\n\x08table_id\x18\x01 \x01(\x04\x12\x13\n\x0btable_class\x18\x02 \x01(\t\x12\x17\n\tshard_num\x18\x03 \x01(\x04:\x04\x31\x30\x30\x30\x12\x30\n\x08\x61\x63\x63\x65ssor\x18\x04 \x01(\x0b\x32\x1e.paddle.TableAccessorParameter\x12\x1f\n\x04type\x18\x05 \x01(\x0e\x32\x11.paddle.TableType\x12\x1f\n\x10\x63ompress_in_save\x18\x06 \x01(\x08:\x05\x66\x61lse\"\xf1\x02\n\x16TableAccessorParameter\x12\x16\n\x0e\x61\x63\x63\x65ssor_class\x18\x01 \x01(\t\x12\x38\n\x10sparse_sgd_param\x18\x02 \x01(\x0b\x32\x1e.paddle.SparseSGDRuleParameter\x12\x36\n\x0f\x64\x65nse_sgd_param\x18\x03 \x01(\x0b\x32\x1d.paddle.DenseSGDRuleParameter\x12\x0f\n\x07\x66\x65\x61_dim\x18\x04 \x01(\r\x12\x12\n\nembedx_dim\x18\x05 \x01(\r\x12\x18\n\x10\x65mbedx_threshold\x18\x06 \x01(\r\x12G\n\x17\x64ownpour_accessor_param\x18\x07 \x01(\x0b\x32&.paddle.DownpourTableAccessorParameter\x12\x45\n\x19table_accessor_save_param\x18\x08 \x03(\x0b\x32\".paddle.TableAccessorSaveParameter\"\xf0\x01\n\x1e\x44ownpourTableAccessorParameter\x12\x14\n\x0cnonclk_coeff\x18\x01 \x01(\x02\x12\x13\n\x0b\x63lick_coeff\x18\x02 \x01(\x02\x12\x16\n\x0e\x62\x61se_threshold\x18\x03 \x01(\x02\x12\x17\n\x0f\x64\x65lta_threshold\x18\x04 \x01(\x02\x12\x17\n\x0f\x64\x65lta_keep_days\x18\x05 \x01(\x02\x12\x1d\n\x15show_click_decay_rate\x18\x06 \x01(\x02\x12\x18\n\x10\x64\x65lete_threshold\x18\x07 \x01(\x02\x12 \n\x18\x64\x65lete_after_unseen_days\x18\x08 \x01(\x02\"S\n\x1aTableAccessorSaveParameter\x12\r\n\x05param\x18\x01 
\x01(\r\x12\x11\n\tconverter\x18\x02 \x01(\t\x12\x13\n\x0b\x64\x65\x63onverter\x18\x03 \x01(\t\"e\n\x10PsRequestMessage\x12\x0e\n\x06\x63md_id\x18\x01 \x02(\r\x12\x10\n\x08table_id\x18\x02 \x01(\r\x12\x0e\n\x06params\x18\x03 \x03(\x0c\x12\x11\n\tclient_id\x18\x04 \x01(\x05\x12\x0c\n\x04\x64\x61ta\x18\x05 \x01(\x0c\"w\n\x16SparseSGDRuleParameter\x12\x15\n\rlearning_rate\x18\x01 \x01(\x01\x12\x15\n\rinitial_g2sum\x18\x02 \x01(\x01\x12\x18\n\rinitial_range\x18\x03 \x01(\x01:\x01\x30\x12\x15\n\rweight_bounds\x18\x04 \x03(\x02\"\xe1\x01\n\x15\x44\x65nseSGDRuleParameter\x12\x0c\n\x04name\x18\x01 \x01(\t\x12&\n\x04\x61\x64\x61m\x18\x02 \x01(\x0b\x32\x18.paddle.AdamSGDParameter\x12(\n\x05naive\x18\x03 \x01(\x0b\x32\x19.paddle.NaiveSGDParameter\x12,\n\x07summary\x18\x04 \x01(\x0b\x32\x1b.paddle.SummarySGDParameter\x12:\n\x0emoving_average\x18\x05 \x01(\x0b\x32\".paddle.MovingAverageRuleParameter\"\x86\x01\n\x10\x41\x64\x61mSGDParameter\x12\x15\n\rlearning_rate\x18\x01 \x01(\x01\x12\x16\n\x0e\x61vg_decay_rate\x18\x02 \x01(\x01\x12\x16\n\x0e\x61\x64\x61_decay_rate\x18\x03 \x01(\x01\x12\x13\n\x0b\x61\x64\x61_epsilon\x18\x04 \x01(\x01\x12\x16\n\x0emom_decay_rate\x18\x05 \x01(\x01\"B\n\x11NaiveSGDParameter\x12\x15\n\rlearning_rate\x18\x01 \x01(\x01\x12\x16\n\x0e\x61vg_decay_rate\x18\x02 \x01(\x01\";\n\x13SummarySGDParameter\x12$\n\x12summary_decay_rate\x18\x01 \x01(\x01:\x08\x30.999999\".\n\x1aMovingAverageRuleParameter\x12\x10\n\x08momentum\x18\x01 \x01(\x01\"I\n\x11PsResponseMessage\x12\x13\n\x08\x65rr_code\x18\x01 \x02(\x05:\x01\x30\x12\x11\n\x07\x65rr_msg\x18\x02 \x02(\t:\x00\x12\x0c\n\x04\x64\x61ta\x18\x03 \x01(\x0c\"\xd5\x01\n\x11\x46sClientParameter\x12:\n\x07\x66s_type\x18\x01 \x01(\x0e\x32#.paddle.FsClientParameter.FsApiType:\x04HDFS\x12\x0b\n\x03uri\x18\x02 \x01(\t\x12\x0c\n\x04user\x18\x03 \x01(\t\x12\x0e\n\x06passwd\x18\x04 \x01(\t\x12\x13\n\x0b\x62uffer_size\x18\x05 \x01(\x05\x12\x12\n\nhadoop_bin\x18\x33 \x01(\t\x12\x10\n\x08\x61\x66s_conf\x18\x65 \x01(\t\"\x1e\n\tFsApiType\x12\x08\n\x04HDFS\x10\x00\x12\x07\n\x03\x41\x46S\x10\x01*4\n\tTableType\x12\x13\n\x0fPS_SPARSE_TABLE\x10\x00\x12\x12\n\x0ePS_DENSE_TABLE\x10\x01*\xbd\x02\n\x07PsCmdID\x12\x17\n\x13PS_PULL_DENSE_TABLE\x10\x00\x12\x17\n\x13PS_PUSH_DENSE_TABLE\x10\x01\x12\x18\n\x14PS_PULL_SPARSE_TABLE\x10\x02\x12\x18\n\x14PS_PUSH_SPARSE_TABLE\x10\x03\x12\x13\n\x0fPS_SHRINK_TABLE\x10\x04\x12\x15\n\x11PS_SAVE_ONE_TABLE\x10\x05\x12\x15\n\x11PS_SAVE_ALL_TABLE\x10\x06\x12\x15\n\x11PS_LOAD_ONE_TABLE\x10\x07\x12\x15\n\x11PS_LOAD_ALL_TABLE\x10\x08\x12\x16\n\x12PS_CLEAR_ONE_TABLE\x10\t\x12\x16\n\x12PS_CLEAR_ALL_TABLE\x10\n\x12\x17\n\x13PS_PUSH_DENSE_PARAM\x10\x0b\x12\x12\n\x0ePS_STOP_SERVER\x10\x0c\x32K\n\tPsService\x12>\n\x07service\x12\x18.paddle.PsRequestMessage\x1a\x19.paddle.PsResponseMessageB\x03\x80\x01\x01' + '\n\x08ps.proto\x12\x06paddle\"\x9e\x02\n\x0bPSParameter\x12\x14\n\x0cworker_class\x18\x01 \x01(\t\x12\x14\n\x0cserver_class\x18\x02 \x01(\t\x12\x16\n\x0einstance_class\x18\x03 \x01(\t\x12-\n\x0cworker_param\x18\x65 \x01(\x0b\x32\x17.paddle.WorkerParameter\x12-\n\x0cserver_param\x18\x66 \x01(\x0b\x32\x17.paddle.ServerParameter\x12\x38\n\rtrainer_param\x18\xad\x02 \x01(\x0b\x32 .paddle.DownpourTrainerParameter\x12\x33\n\x0f\x66s_client_param\x18\xf5\x03 \x01(\x0b\x32\x19.paddle.FsClientParameter\"Q\n\x0fWorkerParameter\x12>\n\x15\x64ownpour_worker_param\x18\x01 \x01(\x0b\x32\x1f.paddle.DownpourWorkerParameter\"Q\n\x0fServerParameter\x12>\n\x15\x64ownpour_server_param\x18\x01 
\x01(\x0b\x32\x1f.paddle.DownpourServerParameter\"O\n\x17\x44ownpourWorkerParameter\x12\x34\n\x14\x64ownpour_table_param\x18\x01 \x03(\x0b\x32\x16.paddle.TableParameter\"\xfd\x01\n\x18\x44ownpourTrainerParameter\x12\x30\n\x0b\x64\x65nse_table\x18\x01 \x03(\x0b\x32\x1b.paddle.DenseTableParameter\x12\x32\n\x0csparse_table\x18\x02 \x03(\x0b\x32\x1c.paddle.SparseTableParameter\x12\x1d\n\x15push_sparse_per_batch\x18\x03 \x01(\x05\x12\x1c\n\x14push_dense_per_batch\x18\x04 \x01(\x05\x12\x0f\n\x07skip_op\x18\x05 \x03(\t\x12-\n\x0eprogram_config\x18\x06 \x03(\x0b\x32\x15.paddle.ProgramConfig\"\x99\x01\n\rProgramConfig\x12\x12\n\nprogram_id\x18\x01 \x02(\t\x12\x1c\n\x14push_sparse_table_id\x18\x02 \x03(\x05\x12\x1b\n\x13push_dense_table_id\x18\x03 \x03(\x05\x12\x1c\n\x14pull_sparse_table_id\x18\x04 \x03(\x05\x12\x1b\n\x13pull_dense_table_id\x18\x05 \x03(\x05\"{\n\x13\x44\x65nseTableParameter\x12\x10\n\x08table_id\x18\x01 \x01(\x05\x12\x1b\n\x13\x64\x65nse_variable_name\x18\x02 \x03(\t\x12$\n\x1c\x64\x65nse_gradient_variable_name\x18\x03 \x03(\t\x12\x0f\n\x07\x66\x65\x61_dim\x18\x04 \x01(\x05\"z\n\x14SparseTableParameter\x12\x10\n\x08table_id\x18\x01 \x01(\x05\x12\x13\n\x0b\x66\x65\x61ture_dim\x18\x02 \x01(\x05\x12\x10\n\x08slot_key\x18\x03 \x03(\t\x12\x12\n\nslot_value\x18\x04 \x03(\t\x12\x15\n\rslot_gradient\x18\x05 \x03(\t\"\x86\x01\n\x17\x44ownpourServerParameter\x12\x34\n\x14\x64ownpour_table_param\x18\x01 \x03(\x0b\x32\x16.paddle.TableParameter\x12\x35\n\rservice_param\x18\x02 \x01(\x0b\x32\x1e.paddle.ServerServiceParameter\"\xd7\x01\n\x16ServerServiceParameter\x12*\n\x0cserver_class\x18\x01 \x01(\t:\x14\x44ownpourBrpcPsServer\x12*\n\x0c\x63lient_class\x18\x02 \x01(\t:\x14\x44ownpourBrpcPsClient\x12(\n\rservice_class\x18\x03 \x01(\t:\x11\x44ownpourPsService\x12\x1c\n\x11start_server_port\x18\x04 \x01(\r:\x01\x30\x12\x1d\n\x11server_thread_num\x18\x05 \x01(\r:\x02\x31\x32\"\xc0\x02\n\x0eTableParameter\x12\x10\n\x08table_id\x18\x01 \x01(\x04\x12\x13\n\x0btable_class\x18\x02 \x01(\t\x12\x17\n\tshard_num\x18\x03 \x01(\x04:\x04\x31\x30\x30\x30\x12\x30\n\x08\x61\x63\x63\x65ssor\x18\x04 \x01(\x0b\x32\x1e.paddle.TableAccessorParameter\x12\x1f\n\x04type\x18\x05 \x01(\x0e\x32\x11.paddle.TableType\x12\x1f\n\x10\x63ompress_in_save\x18\x06 \x01(\x08:\x05\x66\x61lse\x12\'\n\x19\x65nable_sparse_table_cache\x18\x07 \x01(\x08:\x04true\x12(\n\x17sparse_table_cache_rate\x18\x08 \x01(\x01:\x07\x30.00055\x12\'\n\x1bsparse_table_cache_file_num\x18\t \x01(\r:\x02\x31\x36\"\xf1\x02\n\x16TableAccessorParameter\x12\x16\n\x0e\x61\x63\x63\x65ssor_class\x18\x01 \x01(\t\x12\x38\n\x10sparse_sgd_param\x18\x02 \x01(\x0b\x32\x1e.paddle.SparseSGDRuleParameter\x12\x36\n\x0f\x64\x65nse_sgd_param\x18\x03 \x01(\x0b\x32\x1d.paddle.DenseSGDRuleParameter\x12\x0f\n\x07\x66\x65\x61_dim\x18\x04 \x01(\r\x12\x12\n\nembedx_dim\x18\x05 \x01(\r\x12\x18\n\x10\x65mbedx_threshold\x18\x06 \x01(\r\x12G\n\x17\x64ownpour_accessor_param\x18\x07 \x01(\x0b\x32&.paddle.DownpourTableAccessorParameter\x12\x45\n\x19table_accessor_save_param\x18\x08 \x03(\x0b\x32\".paddle.TableAccessorSaveParameter\"\xf0\x01\n\x1e\x44ownpourTableAccessorParameter\x12\x14\n\x0cnonclk_coeff\x18\x01 \x01(\x02\x12\x13\n\x0b\x63lick_coeff\x18\x02 \x01(\x02\x12\x16\n\x0e\x62\x61se_threshold\x18\x03 \x01(\x02\x12\x17\n\x0f\x64\x65lta_threshold\x18\x04 \x01(\x02\x12\x17\n\x0f\x64\x65lta_keep_days\x18\x05 \x01(\x02\x12\x1d\n\x15show_click_decay_rate\x18\x06 \x01(\x02\x12\x18\n\x10\x64\x65lete_threshold\x18\x07 \x01(\x02\x12 \n\x18\x64\x65lete_after_unseen_days\x18\x08 
\x01(\x02\"S\n\x1aTableAccessorSaveParameter\x12\r\n\x05param\x18\x01 \x01(\r\x12\x11\n\tconverter\x18\x02 \x01(\t\x12\x13\n\x0b\x64\x65\x63onverter\x18\x03 \x01(\t\"e\n\x10PsRequestMessage\x12\x0e\n\x06\x63md_id\x18\x01 \x02(\r\x12\x10\n\x08table_id\x18\x02 \x01(\r\x12\x0e\n\x06params\x18\x03 \x03(\x0c\x12\x11\n\tclient_id\x18\x04 \x01(\x05\x12\x0c\n\x04\x64\x61ta\x18\x05 \x01(\x0c\"w\n\x16SparseSGDRuleParameter\x12\x15\n\rlearning_rate\x18\x01 \x01(\x01\x12\x15\n\rinitial_g2sum\x18\x02 \x01(\x01\x12\x18\n\rinitial_range\x18\x03 \x01(\x01:\x01\x30\x12\x15\n\rweight_bounds\x18\x04 \x03(\x02\"\xe1\x01\n\x15\x44\x65nseSGDRuleParameter\x12\x0c\n\x04name\x18\x01 \x01(\t\x12&\n\x04\x61\x64\x61m\x18\x02 \x01(\x0b\x32\x18.paddle.AdamSGDParameter\x12(\n\x05naive\x18\x03 \x01(\x0b\x32\x19.paddle.NaiveSGDParameter\x12,\n\x07summary\x18\x04 \x01(\x0b\x32\x1b.paddle.SummarySGDParameter\x12:\n\x0emoving_average\x18\x05 \x01(\x0b\x32\".paddle.MovingAverageRuleParameter\"\x86\x01\n\x10\x41\x64\x61mSGDParameter\x12\x15\n\rlearning_rate\x18\x01 \x01(\x01\x12\x16\n\x0e\x61vg_decay_rate\x18\x02 \x01(\x01\x12\x16\n\x0e\x61\x64\x61_decay_rate\x18\x03 \x01(\x01\x12\x13\n\x0b\x61\x64\x61_epsilon\x18\x04 \x01(\x01\x12\x16\n\x0emom_decay_rate\x18\x05 \x01(\x01\"B\n\x11NaiveSGDParameter\x12\x15\n\rlearning_rate\x18\x01 \x01(\x01\x12\x16\n\x0e\x61vg_decay_rate\x18\x02 \x01(\x01\";\n\x13SummarySGDParameter\x12$\n\x12summary_decay_rate\x18\x01 \x01(\x01:\x08\x30.999999\".\n\x1aMovingAverageRuleParameter\x12\x10\n\x08momentum\x18\x01 \x01(\x01\"I\n\x11PsResponseMessage\x12\x13\n\x08\x65rr_code\x18\x01 \x02(\x05:\x01\x30\x12\x11\n\x07\x65rr_msg\x18\x02 \x02(\t:\x00\x12\x0c\n\x04\x64\x61ta\x18\x03 \x01(\x0c\"\xd5\x01\n\x11\x46sClientParameter\x12:\n\x07\x66s_type\x18\x01 \x01(\x0e\x32#.paddle.FsClientParameter.FsApiType:\x04HDFS\x12\x0b\n\x03uri\x18\x02 \x01(\t\x12\x0c\n\x04user\x18\x03 \x01(\t\x12\x0e\n\x06passwd\x18\x04 \x01(\t\x12\x13\n\x0b\x62uffer_size\x18\x05 \x01(\x05\x12\x12\n\nhadoop_bin\x18\x33 \x01(\t\x12\x10\n\x08\x61\x66s_conf\x18\x65 \x01(\t\"\x1e\n\tFsApiType\x12\x08\n\x04HDFS\x10\x00\x12\x07\n\x03\x41\x46S\x10\x01*4\n\tTableType\x12\x13\n\x0fPS_SPARSE_TABLE\x10\x00\x12\x12\n\x0ePS_DENSE_TABLE\x10\x01*\x9c\x03\n\x07PsCmdID\x12\x17\n\x13PS_PULL_DENSE_TABLE\x10\x00\x12\x17\n\x13PS_PUSH_DENSE_TABLE\x10\x01\x12\x18\n\x14PS_PULL_SPARSE_TABLE\x10\x02\x12\x18\n\x14PS_PUSH_SPARSE_TABLE\x10\x03\x12\x13\n\x0fPS_SHRINK_TABLE\x10\x04\x12\x15\n\x11PS_SAVE_ONE_TABLE\x10\x05\x12\x15\n\x11PS_SAVE_ALL_TABLE\x10\x06\x12\x15\n\x11PS_LOAD_ONE_TABLE\x10\x07\x12\x15\n\x11PS_LOAD_ALL_TABLE\x10\x08\x12\x16\n\x12PS_CLEAR_ONE_TABLE\x10\t\x12\x16\n\x12PS_CLEAR_ALL_TABLE\x10\n\x12\x17\n\x13PS_PUSH_DENSE_PARAM\x10\x0b\x12\x12\n\x0ePS_STOP_SERVER\x10\x0c\x12\x1b\n\x17PS_SAVE_ONE_CACHE_TABLE\x10\r\x12\x1a\n\x16PS_GET_CACHE_THRESHOLD\x10\x0e\x12\x14\n\x10PS_CACHE_SHUFFLE\x10\x0f\x12\x0e\n\nPS_S2S_MSG\x10\x65\x32K\n\tPsService\x12>\n\x07service\x12\x18.paddle.PsRequestMessage\x1a\x19.paddle.PsResponseMessageB\x03\x80\x01\x01' )) _sym_db.RegisterFileDescriptor(DESCRIPTOR) @@ -49,8 +49,8 @@ _TABLETYPE = _descriptor.EnumDescriptor( ], containing_type=None, options=None, - serialized_start=3528, - serialized_end=3580, ) + serialized_start=3652, + serialized_end=3704, ) _sym_db.RegisterEnumDescriptor(_TABLETYPE) TableType = enum_type_wrapper.EnumTypeWrapper(_TABLETYPE) @@ -131,11 +131,31 @@ _PSCMDID = _descriptor.EnumDescriptor( _descriptor.EnumValueDescriptor( name='PS_STOP_SERVER', index=12, number=12, options=None, type=None), + 
_descriptor.EnumValueDescriptor( + name='PS_SAVE_ONE_CACHE_TABLE', + index=13, + number=13, + options=None, + type=None), + _descriptor.EnumValueDescriptor( + name='PS_GET_CACHE_THRESHOLD', + index=14, + number=14, + options=None, + type=None), + _descriptor.EnumValueDescriptor( + name='PS_CACHE_SHUFFLE', + index=15, + number=15, + options=None, + type=None), + _descriptor.EnumValueDescriptor( + name='PS_S2S_MSG', index=16, number=101, options=None, type=None), ], containing_type=None, options=None, - serialized_start=3583, - serialized_end=3900, ) + serialized_start=3707, + serialized_end=4119, ) _sym_db.RegisterEnumDescriptor(_PSCMDID) PsCmdID = enum_type_wrapper.EnumTypeWrapper(_PSCMDID) @@ -154,6 +174,10 @@ PS_CLEAR_ONE_TABLE = 9 PS_CLEAR_ALL_TABLE = 10 PS_PUSH_DENSE_PARAM = 11 PS_STOP_SERVER = 12 +PS_SAVE_ONE_CACHE_TABLE = 13 +PS_GET_CACHE_THRESHOLD = 14 +PS_CACHE_SHUFFLE = 15 +PS_S2S_MSG = 101 _FSCLIENTPARAMETER_FSAPITYPE = _descriptor.EnumDescriptor( name='FsApiType', @@ -168,8 +192,8 @@ _FSCLIENTPARAMETER_FSAPITYPE = _descriptor.EnumDescriptor( ], containing_type=None, options=None, - serialized_start=3496, - serialized_end=3526, ) + serialized_start=3620, + serialized_end=3650, ) _sym_db.RegisterEnumDescriptor(_FSCLIENTPARAMETER_FSAPITYPE) _PSPARAMETER = _descriptor.Descriptor( @@ -1057,6 +1081,54 @@ _TABLEPARAMETER = _descriptor.Descriptor( is_extension=False, extension_scope=None, options=None), + _descriptor.FieldDescriptor( + name='enable_sparse_table_cache', + full_name='paddle.TableParameter.enable_sparse_table_cache', + index=6, + number=7, + type=8, + cpp_type=7, + label=1, + has_default_value=True, + default_value=True, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None), + _descriptor.FieldDescriptor( + name='sparse_table_cache_rate', + full_name='paddle.TableParameter.sparse_table_cache_rate', + index=7, + number=8, + type=1, + cpp_type=5, + label=1, + has_default_value=True, + default_value=float(0.00055), + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None), + _descriptor.FieldDescriptor( + name='sparse_table_cache_file_num', + full_name='paddle.TableParameter.sparse_table_cache_file_num', + index=8, + number=9, + type=13, + cpp_type=3, + label=1, + has_default_value=True, + default_value=16, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None), ], extensions=[], nested_types=[], @@ -1067,7 +1139,7 @@ _TABLEPARAMETER = _descriptor.Descriptor( extension_ranges=[], oneofs=[], serialized_start=1573, - serialized_end=1769, ) + serialized_end=1893, ) _TABLEACCESSORPARAMETER = _descriptor.Descriptor( name='TableAccessorParameter', @@ -1213,8 +1285,8 @@ _TABLEACCESSORPARAMETER = _descriptor.Descriptor( syntax='proto2', extension_ranges=[], oneofs=[], - serialized_start=1772, - serialized_end=2141, ) + serialized_start=1896, + serialized_end=2265, ) _DOWNPOURTABLEACCESSORPARAMETER = _descriptor.Descriptor( name='DownpourTableAccessorParameter', @@ -1360,8 +1432,8 @@ _DOWNPOURTABLEACCESSORPARAMETER = _descriptor.Descriptor( syntax='proto2', extension_ranges=[], oneofs=[], - serialized_start=2144, - serialized_end=2384, ) + serialized_start=2268, + serialized_end=2508, ) _TABLEACCESSORSAVEPARAMETER = _descriptor.Descriptor( name='TableAccessorSaveParameter', @@ -1427,8 +1499,8 @@ _TABLEACCESSORSAVEPARAMETER = _descriptor.Descriptor( syntax='proto2', 
extension_ranges=[], oneofs=[], - serialized_start=2386, - serialized_end=2469, ) + serialized_start=2510, + serialized_end=2593, ) _PSREQUESTMESSAGE = _descriptor.Descriptor( name='PsRequestMessage', @@ -1526,8 +1598,8 @@ _PSREQUESTMESSAGE = _descriptor.Descriptor( syntax='proto2', extension_ranges=[], oneofs=[], - serialized_start=2471, - serialized_end=2572, ) + serialized_start=2595, + serialized_end=2696, ) _SPARSESGDRULEPARAMETER = _descriptor.Descriptor( name='SparseSGDRuleParameter', @@ -1609,8 +1681,8 @@ _SPARSESGDRULEPARAMETER = _descriptor.Descriptor( syntax='proto2', extension_ranges=[], oneofs=[], - serialized_start=2574, - serialized_end=2693, ) + serialized_start=2698, + serialized_end=2817, ) _DENSESGDRULEPARAMETER = _descriptor.Descriptor( name='DenseSGDRuleParameter', @@ -1708,8 +1780,8 @@ _DENSESGDRULEPARAMETER = _descriptor.Descriptor( syntax='proto2', extension_ranges=[], oneofs=[], - serialized_start=2696, - serialized_end=2921, ) + serialized_start=2820, + serialized_end=3045, ) _ADAMSGDPARAMETER = _descriptor.Descriptor( name='AdamSGDParameter', @@ -1807,8 +1879,8 @@ _ADAMSGDPARAMETER = _descriptor.Descriptor( syntax='proto2', extension_ranges=[], oneofs=[], - serialized_start=2924, - serialized_end=3058, ) + serialized_start=3048, + serialized_end=3182, ) _NAIVESGDPARAMETER = _descriptor.Descriptor( name='NaiveSGDParameter', @@ -1858,8 +1930,8 @@ _NAIVESGDPARAMETER = _descriptor.Descriptor( syntax='proto2', extension_ranges=[], oneofs=[], - serialized_start=3060, - serialized_end=3126, ) + serialized_start=3184, + serialized_end=3250, ) _SUMMARYSGDPARAMETER = _descriptor.Descriptor( name='SummarySGDParameter', @@ -1893,8 +1965,8 @@ _SUMMARYSGDPARAMETER = _descriptor.Descriptor( syntax='proto2', extension_ranges=[], oneofs=[], - serialized_start=3128, - serialized_end=3187, ) + serialized_start=3252, + serialized_end=3311, ) _MOVINGAVERAGERULEPARAMETER = _descriptor.Descriptor( name='MovingAverageRuleParameter', @@ -1928,8 +2000,8 @@ _MOVINGAVERAGERULEPARAMETER = _descriptor.Descriptor( syntax='proto2', extension_ranges=[], oneofs=[], - serialized_start=3189, - serialized_end=3235, ) + serialized_start=3313, + serialized_end=3359, ) _PSRESPONSEMESSAGE = _descriptor.Descriptor( name='PsResponseMessage', @@ -1995,8 +2067,8 @@ _PSRESPONSEMESSAGE = _descriptor.Descriptor( syntax='proto2', extension_ranges=[], oneofs=[], - serialized_start=3237, - serialized_end=3310, ) + serialized_start=3361, + serialized_end=3434, ) _FSCLIENTPARAMETER = _descriptor.Descriptor( name='FsClientParameter', @@ -2126,8 +2198,8 @@ _FSCLIENTPARAMETER = _descriptor.Descriptor( syntax='proto2', extension_ranges=[], oneofs=[], - serialized_start=3313, - serialized_end=3526, ) + serialized_start=3437, + serialized_end=3650, ) _PSPARAMETER.fields_by_name['worker_param'].message_type = _WORKERPARAMETER _PSPARAMETER.fields_by_name['server_param'].message_type = _SERVERPARAMETER diff --git a/python/paddle/fluid/tests/unittests/test_dataset.py b/python/paddle/fluid/tests/unittests/test_dataset.py index cd12cd23073a9bb96bd4f06ca84d57f2f303987b..5575f73883d7b58da02209ef1a6ae8a7e40a3c14 100644 --- a/python/paddle/fluid/tests/unittests/test_dataset.py +++ b/python/paddle/fluid/tests/unittests/test_dataset.py @@ -109,6 +109,8 @@ class TestDataset(unittest.TestCase): dataset.set_pipe_command("cat") dataset.set_use_var(slots_vars) dataset.load_into_memory() + dataset.set_fea_eval(10000, True) + dataset.slots_shuffle(["slot1"]) dataset.local_shuffle() exe = fluid.Executor(fluid.CPUPlace())
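
Note on the slots shuffle mechanism added above: RecordCandidateList::AddAndGet keeps a fixed-capacity reservoir of the records seen so far (classic reservoir sampling, so every record has an equal chance of staying in the pool), and MultiSlotDataset::GetRandomData then erases the feasigns of the selected slots from each record and refills those slots from a uniformly sampled candidate. The Python sketch below only illustrates that sampling-and-replacement logic as read from the C++ changes; it is not part of the framework, and all class, function, and field names in it are illustrative.

    # Illustrative-only sketch of the slots shuffle sampling (not the framework code).
    import random


    class RecordCandidateList(object):
        """Fixed-capacity reservoir of candidate records."""

        def __init__(self, capacity):
            self.capacity = capacity
            self.candidates = []
            self.total_seen = 0

        def add_and_get(self, record):
            """Offer one record to the reservoir, then return a random candidate."""
            self.total_seen += 1
            if len(self.candidates) < self.capacity:
                self.candidates.append(record)
            else:
                # Reservoir sampling: each of the total_seen records is kept
                # with probability capacity / total_seen.
                index = random.randrange(self.total_seen)
                if index < self.capacity:
                    self.candidates[index] = record
            return random.choice(self.candidates)


    def slots_shuffle(records, slots_to_replace, record_candidate_size=1024):
        """For each record, drop feasigns of the chosen slots and refill them
        from a randomly sampled candidate record (what GetRandomData does)."""
        rclist = RecordCandidateList(record_candidate_size)
        shuffled = []
        for rec in records:  # rec: {"ins_id": ..., "feasigns": [(slot, sign), ...]}
            rand_rec = rclist.add_and_get(rec)
            kept = [(s, v) for (s, v) in rec["feasigns"] if s not in slots_to_replace]
            borrowed = [(s, v) for (s, v) in rand_rec["feasigns"] if s in slots_to_replace]
            shuffled.append({"ins_id": rec["ins_id"], "feasigns": kept + borrowed})
        return shuffled


    if __name__ == "__main__":
        data = [{"ins_id": str(i), "feasigns": [("slot1", i), ("slot2", 100 + i)]}
                for i in range(10)]
        # Replace "slot1" everywhere while leaving "slot2" untouched.
        print(slots_shuffle(data, slots_to_replace={"slot1"}, record_candidate_size=4))

Under this reading, the reservoir keeps memory bounded by the record_candidate_size passed to set_fea_eval while still drawing replacement feasigns from roughly the whole pass over the data, which is why the candidate size is exposed as a user-tunable argument.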