未验证 提交 9150cf50 编写于 作者: Y yaoxuefeng 提交者: GitHub

add save cache model api in fleet& add slots shuffle in dataset module & add...

add save cache model api in fleet& add slots shuffle in dataset module & add metric op to calculate ctr related metrics (#18871)

* add ctr related metric layer test=develop

* add save cache and slots shuffle test=develop

* add save cache and slots shuffle test=develop

* fix error

* fix error

* fix style for ci

* fix for comments

* change SlotsShuffle input to std::strinf for generality

* fix style

* fix style

* fix style

* fix style

* fix style

* fix style

* fix stylr

* fix style

* fix style

* fix style

* fix style

* fix style

* fix style

* fix style

* fix style

* fix style

* fix style

* fix style

* fix style

* fix style

* change non-const reference to pointer

* fix style

* fix style

* fix style test=develop

* fix style  test=develop

* add return ins num in ctr metric op

* change dtype to float in metric_op.py

* fix error test=develop

* fix style test=develop

* fix API spec

* fix API spec

* fix API spec test=develop

* add UT test=develop
上级 b7b584b0
......@@ -518,6 +518,7 @@ paddle.fluid.contrib.BasicLSTMUnit.state_dict (ArgSpec(args=['self', 'destinatio
paddle.fluid.contrib.BasicLSTMUnit.sublayers (ArgSpec(args=['self', 'include_sublayers'], varargs=None, keywords=None, defaults=(True,)), ('document', '00a881005ecbc96578faf94513bf0d62'))
paddle.fluid.contrib.BasicLSTMUnit.train (ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None), ('document', '6adf97f83acf6453d4a6a4b1070f3754'))
paddle.fluid.contrib.basic_lstm (ArgSpec(args=['input', 'init_hidden', 'init_cell', 'hidden_size', 'num_layers', 'sequence_length', 'dropout_prob', 'bidirectional', 'batch_first', 'param_attr', 'bias_attr', 'gate_activation', 'activation', 'forget_bias', 'dtype', 'name'], varargs=None, keywords=None, defaults=(1, None, 0.0, False, True, None, None, None, None, 1.0, 'float32', 'basic_lstm')), ('document', 'fe4d0c3c55a162b8cfe10b05fabb7ce4'))
paddle.fluid.contrib.ctr_metric_bundle (ArgSpec(args=['input', 'label'], varargs=None, keywords=None, defaults=None), ('document', 'b68d12366896c41065fc3738393da2aa'))
paddle.fluid.dygraph.Layer ('paddle.fluid.dygraph.layers.Layer', ('document', 'a889d5affd734ede273e94d4257163ab'))
paddle.fluid.dygraph.Layer.__init__ (ArgSpec(args=['self', 'name_scope', 'dtype'], varargs=None, keywords=None, defaults=(VarType.FP32,)), ('document', '6adf97f83acf6453d4a6a4b1070f3754'))
paddle.fluid.dygraph.Layer.add_parameter (ArgSpec(args=['self', 'name', 'parameter'], varargs=None, keywords=None, defaults=None), ('document', 'f35ab374c7d5165c3daf3bd64a5a2ec1'))
......
......@@ -33,11 +33,53 @@ limitations under the License. */
#include "io/shell.h"
#include "paddle/fluid/framework/feed_fetch_method.h"
#include "paddle/fluid/framework/feed_fetch_type.h"
#include "paddle/fluid/framework/fleet/fleet_wrapper.h"
#include "paddle/fluid/platform/timer.h"
namespace paddle {
namespace framework {
void RecordCandidateList::ReSize(size_t length) {
_mutex.lock();
_capacity = length;
CHECK(_capacity > 0); // NOLINT
_candidate_list.clear();
_candidate_list.resize(_capacity);
_full = false;
_cur_size = 0;
_total_size = 0;
_mutex.unlock();
}
void RecordCandidateList::ReInit() {
_mutex.lock();
_full = false;
_cur_size = 0;
_total_size = 0;
_mutex.unlock();
}
void RecordCandidateList::AddAndGet(const Record& record,
RecordCandidate* result) {
_mutex.lock();
size_t index = 0;
++_total_size;
auto fleet_ptr = FleetWrapper::GetInstance();
if (!_full) {
_candidate_list[_cur_size++] = record;
_full = (_cur_size == _capacity);
} else {
CHECK(_cur_size == _capacity);
index = fleet_ptr->LocalRandomEngine()() % _total_size;
if (index < _capacity) {
_candidate_list[index] = record;
}
}
index = fleet_ptr->LocalRandomEngine()() % _cur_size;
*result = _candidate_list[index];
_mutex.unlock();
}
void DataFeed::AddFeedVar(Variable* var, const std::string& name) {
CheckInit();
for (size_t i = 0; i < use_slots_.size(); ++i) {
......
......@@ -26,6 +26,7 @@ limitations under the License. */
#include <sstream>
#include <string>
#include <thread> // NOLINT
#include <unordered_map>
#include <utility>
#include <vector>
......@@ -427,6 +428,41 @@ struct Record {
std::string ins_id_;
};
struct RecordCandidate {
std::string ins_id_;
std::unordered_multimap<uint16_t, FeatureKey> feas;
RecordCandidate& operator=(const Record& rec) {
feas.clear();
ins_id_ = rec.ins_id_;
for (auto& fea : rec.uint64_feasigns_) {
feas.insert({fea.slot(), fea.sign()});
}
return *this;
}
};
class RecordCandidateList {
public:
RecordCandidateList() = default;
RecordCandidateList(const RecordCandidateList&) = delete;
RecordCandidateList& operator=(const RecordCandidateList&) = delete;
void ReSize(size_t length);
void ReInit();
void AddAndGet(const Record& record, RecordCandidate* result);
private:
size_t _capacity = 0;
std::mutex _mutex;
bool _full = false;
size_t _cur_size = 0;
size_t _total_size = 0;
std::vector<RecordCandidate> _candidate_list;
};
template <class AR>
paddle::framework::Archive<AR>& operator<<(paddle::framework::Archive<AR>& ar,
const FeatureKey& fk) {
......
......@@ -114,6 +114,14 @@ void DatasetImpl<T>::SetMergeByInsId(
keep_unmerged_ins_ = keep_unmerged_ins;
}
template <typename T>
void DatasetImpl<T>::SetFeaEval(bool fea_eval, int record_candidate_size) {
slots_shuffle_fea_eval_ = fea_eval;
slots_shuffle_rclist_.ReSize(record_candidate_size);
VLOG(3) << "SetFeaEval fea eval mode: " << fea_eval
<< " with record candidate size: " << record_candidate_size;
}
template <typename T>
std::vector<paddle::framework::DataFeed*> DatasetImpl<T>::GetReaders() {
std::vector<paddle::framework::DataFeed*> ret;
......@@ -646,5 +654,167 @@ void MultiSlotDataset::MergeByInsId() {
VLOG(3) << "MultiSlotDataset::MergeByInsId end";
}
void MultiSlotDataset::GetRandomData(const std::set<uint16_t>& slots_to_replace,
std::vector<Record>* result) {
int debug_erase_cnt = 0;
int debug_push_cnt = 0;
auto multi_slot_desc = data_feed_desc_.multi_slot_desc();
slots_shuffle_rclist_.ReInit();
for (const auto& rec : slots_shuffle_original_data_) {
RecordCandidate rand_rec;
Record new_rec = rec;
slots_shuffle_rclist_.AddAndGet(rec, &rand_rec);
for (auto it = new_rec.uint64_feasigns_.begin();
it != new_rec.uint64_feasigns_.end();) {
if (slots_to_replace.find(it->slot()) != slots_to_replace.end()) {
it = new_rec.uint64_feasigns_.erase(it);
debug_erase_cnt += 1;
} else {
++it;
}
}
for (auto slot : slots_to_replace) {
auto range = rand_rec.feas.equal_range(slot);
for (auto it = range.first; it != range.second; ++it) {
new_rec.uint64_feasigns_.push_back({it->second, it->first});
debug_push_cnt += 1;
}
}
result->push_back(std::move(new_rec));
}
VLOG(2) << "erase feasign num: " << debug_erase_cnt
<< " repush feasign num: " << debug_push_cnt;
}
// slots shuffle to input_channel_ with needed-shuffle slots
void MultiSlotDataset::SlotsShuffle(
const std::set<std::string>& slots_to_replace) {
int out_channel_size = 0;
if (cur_channel_ == 0) {
for (size_t i = 0; i < multi_output_channel_.size(); ++i) {
out_channel_size += multi_output_channel_[i]->Size();
}
} else {
for (size_t i = 0; i < multi_consume_channel_.size(); ++i) {
out_channel_size += multi_consume_channel_[i]->Size();
}
}
VLOG(2) << "DatasetImpl<T>::SlotsShuffle() begin with input channel size: "
<< input_channel_->Size()
<< " output channel size: " << out_channel_size;
if (!slots_shuffle_fea_eval_) {
VLOG(3) << "DatasetImpl<T>::SlotsShuffle() end,"
"fea eval mode off, need to set on for slots shuffle";
return;
}
if ((!input_channel_ || input_channel_->Size() == 0) &&
slots_shuffle_original_data_.size() == 0 && out_channel_size == 0) {
VLOG(3) << "DatasetImpl<T>::SlotsShuffle() end, no data to slots shuffle";
return;
}
platform::Timer timeline;
timeline.Start();
auto multi_slot_desc = data_feed_desc_.multi_slot_desc();
std::set<uint16_t> index_slots;
for (size_t i = 0; i < multi_slot_desc.slots_size(); ++i) {
std::string cur_slot = multi_slot_desc.slots(i).name();
if (slots_to_replace.find(cur_slot) != slots_to_replace.end()) {
index_slots.insert(i);
}
}
if (slots_shuffle_original_data_.size() == 0) {
// before first slots shuffle, instances could be in
// input_channel, oupput_channel or consume_channel
if (input_channel_ && input_channel_->Size() != 0) {
slots_shuffle_original_data_.reserve(input_channel_->Size());
input_channel_->Close();
input_channel_->ReadAll(slots_shuffle_original_data_);
} else {
CHECK(out_channel_size > 0); // NOLINT
if (cur_channel_ == 0) {
for (size_t i = 0; i < multi_output_channel_.size(); ++i) {
std::vector<Record> vec_data;
multi_output_channel_[i]->Close();
multi_output_channel_[i]->ReadAll(vec_data);
slots_shuffle_original_data_.reserve(
slots_shuffle_original_data_.size() + vec_data.size());
slots_shuffle_original_data_.insert(
slots_shuffle_original_data_.end(),
std::make_move_iterator(vec_data.begin()),
std::make_move_iterator(vec_data.end()));
vec_data.clear();
vec_data.shrink_to_fit();
multi_output_channel_[i]->Clear();
}
} else {
for (size_t i = 0; i < multi_consume_channel_.size(); ++i) {
std::vector<Record> vec_data;
multi_consume_channel_[i]->Close();
multi_consume_channel_[i]->ReadAll(vec_data);
slots_shuffle_original_data_.reserve(
slots_shuffle_original_data_.size() + vec_data.size());
slots_shuffle_original_data_.insert(
slots_shuffle_original_data_.end(),
std::make_move_iterator(vec_data.begin()),
std::make_move_iterator(vec_data.end()));
vec_data.clear();
vec_data.shrink_to_fit();
multi_consume_channel_[i]->Clear();
}
}
}
} else {
// if already have original data for slots shuffle, clear channel
input_channel_->Clear();
if (cur_channel_ == 0) {
for (size_t i = 0; i < multi_output_channel_.size(); ++i) {
if (!multi_output_channel_[i]) {
continue;
}
multi_output_channel_[i]->Clear();
}
} else {
for (size_t i = 0; i < multi_consume_channel_.size(); ++i) {
if (!multi_consume_channel_[i]) {
continue;
}
multi_consume_channel_[i]->Clear();
}
}
}
int end_size = 0;
if (cur_channel_ == 0) {
for (size_t i = 0; i < multi_output_channel_.size(); ++i) {
if (!multi_output_channel_[i]) {
continue;
}
end_size += multi_output_channel_[i]->Size();
}
} else {
for (size_t i = 0; i < multi_consume_channel_.size(); ++i) {
if (!multi_consume_channel_[i]) {
continue;
}
end_size += multi_consume_channel_[i]->Size();
}
}
CHECK(input_channel_->Size() == 0)
<< "input channel should be empty before slots shuffle";
std::vector<Record> random_data;
random_data.clear();
// get slots shuffled random_data
GetRandomData(index_slots, &random_data);
input_channel_->Open();
input_channel_->Write(std::move(random_data));
random_data.clear();
random_data.shrink_to_fit();
input_channel_->Close();
timeline.Pause();
VLOG(2) << "DatasetImpl<T>::SlotsShuffle() end"
<< ", memory data size for slots shuffle=" << input_channel_->Size()
<< ", cost time=" << timeline.ElapsedSec() << " seconds";
}
} // end namespace framework
} // end namespace paddle
......@@ -17,6 +17,7 @@
#include <fstream>
#include <memory>
#include <mutex> // NOLINT
#include <set>
#include <string>
#include <thread> // NOLINT
#include <utility>
......@@ -61,6 +62,8 @@ class Dataset {
virtual void SetMergeByInsId(const std::vector<std::string>& merge_slot_list,
bool erase_duplicate_feas, int min_merge_size,
bool keep_unmerged_ins) = 0;
// set fea eval mode
virtual void SetFeaEval(bool fea_eval, int record_candidate_size) = 0;
// get file list
virtual const std::vector<std::string>& GetFileList() = 0;
// get thread num
......@@ -94,6 +97,10 @@ class Dataset {
virtual void LocalShuffle() = 0;
// global shuffle data
virtual void GlobalShuffle() = 0;
// for slots shuffle
virtual void SlotsShuffle(const std::set<std::string>& slots_to_replace) = 0;
virtual void GetRandomData(const std::set<uint16_t>& slots_to_replace,
std::vector<Record>* result) = 0;
// create readers
virtual void CreateReaders() = 0;
// destroy readers
......@@ -130,6 +137,7 @@ class DatasetImpl : public Dataset {
bool erase_duplicate_feas, int min_merge_size,
bool keep_unmerged_ins);
virtual void SetFeaEval(bool fea_eval, int record_candidate_size);
virtual const std::vector<std::string>& GetFileList() { return filelist_; }
virtual int GetThreadNum() { return thread_num_; }
virtual int GetTrainerNum() { return trainer_num_; }
......@@ -150,6 +158,9 @@ class DatasetImpl : public Dataset {
virtual void ReleaseMemory();
virtual void LocalShuffle();
virtual void GlobalShuffle();
virtual void SlotsShuffle(const std::set<std::string>& slots_to_replace) {}
virtual void GetRandomData(const std::set<uint16_t>& slots_to_replace,
std::vector<Record>* result) {}
virtual void CreateReaders();
virtual void DestroyReaders();
virtual int64_t GetMemoryDataSize();
......@@ -168,6 +179,8 @@ class DatasetImpl : public Dataset {
// and when finish reading, we set cur_channel = 1 - cur_channel,
// so if cur_channel=0, all data are in output_channel, else consume_channel
int cur_channel_;
std::vector<T> slots_shuffle_original_data_;
RecordCandidateList slots_shuffle_rclist_;
int thread_num_;
paddle::framework::DataFeedDesc data_feed_desc_;
int trainer_num_;
......@@ -184,6 +197,7 @@ class DatasetImpl : public Dataset {
bool keep_unmerged_ins_;
int min_merge_size_;
std::vector<std::string> merge_slots_list_;
bool slots_shuffle_fea_eval_ = false;
};
// use std::vector<MultiSlotType> or Record as data type
......@@ -191,6 +205,9 @@ class MultiSlotDataset : public DatasetImpl<Record> {
public:
MultiSlotDataset() {}
virtual void MergeByInsId();
virtual void SlotsShuffle(const std::set<std::string>& slots_to_replace);
virtual void GetRandomData(const std::set<uint16_t>& slots_to_replace,
std::vector<Record>* result);
virtual ~MultiSlotDataset() {}
};
......
......@@ -512,6 +512,57 @@ void FleetWrapper::SaveModel(const std::string& path, const int mode) {
#endif
}
double FleetWrapper::GetCacheThreshold() {
#ifdef PADDLE_WITH_PSLIB
double cache_threshold = 0.0;
auto ret = pslib_ptr_->_worker_ptr->flush();
ret.wait();
ret = pslib_ptr_->_worker_ptr->get_cache_threshold(0, cache_threshold);
ret.wait();
if (cache_threshold < 0) {
LOG(ERROR) << "get cache threshold failed";
exit(-1);
}
return cache_threshold;
#else
VLOG(0) << "FleetWrapper::GetCacheThreshold does nothing when no pslib";
return 0.0;
#endif
}
void FleetWrapper::CacheShuffle(int table_id, const std::string& path,
const int mode, const double cache_threshold) {
#ifdef PADDLE_WITH_PSLIB
auto ret = pslib_ptr_->_worker_ptr->cache_shuffle(
0, path, std::to_string(mode), std::to_string(cache_threshold));
ret.wait();
int32_t feasign_cnt = ret.get();
if (feasign_cnt == -1) {
LOG(ERROR) << "cache shuffle failed";
exit(-1);
}
#else
VLOG(0) << "FleetWrapper::CacheShuffle does nothing when no pslib";
#endif
}
int32_t FleetWrapper::SaveCache(int table_id, const std::string& path,
const int mode) {
#ifdef PADDLE_WITH_PSLIB
auto ret = pslib_ptr_->_worker_ptr->save_cache(0, path, std::to_string(mode));
ret.wait();
int32_t feasign_cnt = ret.get();
if (feasign_cnt == -1) {
LOG(ERROR) << "table save cache failed";
exit(-1);
}
return feasign_cnt;
#else
VLOG(0) << "FleetWrapper::SaveCache does nothing when no pslib";
return -1;
#endif
}
void FleetWrapper::ShrinkSparseTable(int table_id) {
#ifdef PADDLE_WITH_PSLIB
auto ret = pslib_ptr_->_worker_ptr->shrink(table_id);
......
......@@ -148,7 +148,13 @@ class FleetWrapper {
// mode = 1, save delta feature, which means save diff
void SaveModel(const std::string& path, const int mode);
double GetCacheThreshold();
void CacheShuffle(int table_id, const std::string& path, const int mode,
const double cache_threshold);
int32_t SaveCache(int table_id, const std::string& path, const int mode);
void ClearModel();
void ShrinkSparseTable(int table_id);
void ShrinkDenseTable(int table_id, Scope* scope,
std::vector<std::string> var_list, float decay,
......
......@@ -103,6 +103,10 @@ void BindDataset(py::module* m) {
.def("set_merge_by_lineid", &framework::Dataset::SetMergeByInsId,
py::call_guard<py::gil_scoped_release>())
.def("merge_by_lineid", &framework::Dataset::MergeByInsId,
py::call_guard<py::gil_scoped_release>())
.def("slots_shuffle", &framework::Dataset::SlotsShuffle,
py::call_guard<py::gil_scoped_release>())
.def("set_fea_eval", &framework::Dataset::SetFeaEval,
py::call_guard<py::gil_scoped_release>());
}
......
......@@ -49,6 +49,9 @@ void BindFleetWrapper(py::module* m) {
.def("init_worker", &framework::FleetWrapper::InitWorker)
.def("init_model", &framework::FleetWrapper::PushDenseParamSync)
.def("save_model", &framework::FleetWrapper::SaveModel)
.def("get_cache_threshold", &framework::FleetWrapper::GetCacheThreshold)
.def("cache_shuffle", &framework::FleetWrapper::CacheShuffle)
.def("save_cache", &framework::FleetWrapper::SaveCache)
.def("load_model", &framework::FleetWrapper::LoadModel)
.def("clear_model", &framework::FleetWrapper::ClearModel)
.def("stop_server", &framework::FleetWrapper::StopServer)
......
......@@ -16,8 +16,12 @@ from __future__ import print_function
from . import nn
from .nn import *
from .rnn_impl import *
from . import metric_op
from .metric_op import *
__all__ = []
__all__ += nn.__all__
__all__ += rnn_impl.__all__
__all__ += metric_op.__all__
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Contrib layers just related to metric.
"""
from __future__ import print_function
import warnings
from paddle.fluid.layer_helper import LayerHelper
from paddle.fluid.initializer import Normal, Constant
from paddle.fluid.framework import Variable
from paddle.fluid.param_attr import ParamAttr
from paddle.fluid.layers import nn
__all__ = ['ctr_metric_bundle']
def ctr_metric_bundle(input, label):
"""
ctr related metric layer
This function help compute the ctr related metrics: RMSE, MAE, predicted_ctr, q_value.
To compute the final values of these metrics, we should do following computations using
total instance number:
MAE = local_abserr / instance number
RMSE = sqrt(local_sqrerr / instance number)
predicted_ctr = local_prob / instance number
q = local_q / instance number
Note that if you are doing distribute job, you should all reduce these metrics and instance
number first
Args:
input(Variable): A floating-point 2D Variable, values are in the range
[0, 1]. Each row is sorted in descending order. This
input should be the output of topk. Typically, this
Variable indicates the probability of each label.
label(Variable): A 2D int Variable indicating the label of the training
data. The height is batch size and width is always 1.
Returns:
local_sqrerr(Variable): Local sum of squared error
local_abserr(Variable): Local sum of abs error
local_prob(Variable): Local sum of predicted ctr
local_q(Variable): Local sum of q value
Examples:
.. code-block:: python
import paddle.fluid as fluid
data = fluid.layers.data(name="data", shape=[32, 32], dtype="float32")
label = fluid.layers.data(name="label", shape=[1], dtype="int32")
predict = fluid.layers.sigmoid(fluid.layers.fc(input=data, size=1))
auc_out = fluid.contrib.layers.ctr_metric_bundle(input=predict, label=label)
"""
assert input.shape == label.shape
helper = LayerHelper("ctr_metric_bundle", **locals())
local_abserr = helper.create_global_variable(
persistable=True, dtype='float32', shape=[1])
local_sqrerr = helper.create_global_variable(
persistable=True, dtype='float32', shape=[1])
local_prob = helper.create_global_variable(
persistable=True, dtype='float32', shape=[1])
local_q = helper.create_global_variable(
persistable=True, dtype='float32', shape=[1])
local_pos_num = helper.create_global_variable(
persistable=True, dtype='float32', shape=[1])
local_ins_num = helper.create_global_variable(
persistable=True, dtype='float32', shape=[1])
tmp_res_elesub = helper.create_global_variable(
persistable=False, dtype='float32', shape=[-1])
tmp_res_sigmoid = helper.create_global_variable(
persistable=False, dtype='float32', shape=[-1])
tmp_ones = helper.create_global_variable(
persistable=False, dtype='float32', shape=[-1])
batch_prob = helper.create_global_variable(
persistable=False, dtype='float32', shape=[1])
batch_abserr = helper.create_global_variable(
persistable=False, dtype='float32', shape=[1])
batch_sqrerr = helper.create_global_variable(
persistable=False, dtype='float32', shape=[1])
batch_q = helper.create_global_variable(
persistable=False, dtype='float32', shape=[1])
batch_pos_num = helper.create_global_variable(
persistable=False, dtype='float32', shape=[1])
batch_ins_num = helper.create_global_variable(
persistable=False, dtype='float32', shape=[1])
for var in [
local_abserr, batch_abserr, local_sqrerr, batch_sqrerr, local_prob,
batch_prob, local_q, batch_q, batch_pos_num, batch_ins_num,
local_pos_num, local_ins_num
]:
helper.set_variable_initializer(
var, Constant(
value=0.0, force_cpu=True))
helper.append_op(
type="elementwise_sub",
inputs={"X": [input],
"Y": [label]},
outputs={"Out": [tmp_res_elesub]})
helper.append_op(
type="squared_l2_norm",
inputs={"X": [tmp_res_elesub]},
outputs={"Out": [batch_sqrerr]})
helper.append_op(
type="elementwise_add",
inputs={"X": [batch_sqrerr],
"Y": [local_sqrerr]},
outputs={"Out": [local_sqrerr]})
helper.append_op(
type="l1_norm",
inputs={"X": [tmp_res_elesub]},
outputs={"Out": [batch_abserr]})
helper.append_op(
type="elementwise_add",
inputs={"X": [batch_abserr],
"Y": [local_abserr]},
outputs={"Out": [local_abserr]})
helper.append_op(
type="reduce_sum", inputs={"X": [input]},
outputs={"Out": [batch_prob]})
helper.append_op(
type="elementwise_add",
inputs={"X": [batch_prob],
"Y": [local_prob]},
outputs={"Out": [local_prob]})
helper.append_op(
type="sigmoid",
inputs={"X": [input]},
outputs={"Out": [tmp_res_sigmoid]})
helper.append_op(
type="reduce_sum",
inputs={"X": [tmp_res_sigmoid]},
outputs={"Out": [batch_q]})
helper.append_op(
type="elementwise_add",
inputs={"X": [batch_q],
"Y": [local_q]},
outputs={"Out": [local_q]})
helper.append_op(
type="reduce_sum",
inputs={"X": [label]},
outputs={"Out": [batch_pos_num]})
helper.append_op(
type="elementwise_add",
inputs={"X": [batch_pos_num],
"Y": [local_pos_num]},
outputs={"Out": [local_pos_num]})
helper.append_op(
type='fill_constant_batch_size_like',
inputs={"Input": label},
outputs={'Out': [tmp_ones]},
attrs={
'shape': [-1, 1],
'dtype': tmp_ones.dtype,
'value': float(1.0),
})
helper.append_op(
type="reduce_sum",
inputs={"X": [tmp_ones]},
outputs={"Out": [batch_ins_num]})
helper.append_op(
type="elementwise_add",
inputs={"X": [batch_ins_num],
"Y": [local_ins_num]},
outputs={"Out": [local_ins_num]})
return local_sqrerr, local_abserr, local_prob, local_q, local_pos_num, local_ins_num
......@@ -91,6 +91,51 @@ class DatasetBase(object):
"""
self.proto_desc.pipe_command = pipe_command
def set_fea_eval(self, record_candidate_size, fea_eval=True):
"""
set fea eval mode for slots shuffle to debug the importance level of
slots(features), fea_eval need to be set True for slots shuffle.
Args:
record_candidate_size(int): size of instances candidate to shuffle
one slot
fea_eval(bool): wheather enable fea eval mode to enable slots shuffle.
default is True.
Examples:
.. code-block:: python
import paddle.fluid as fluid
dataset = fluid.DatasetFactory().create_dataset("InMemoryDataset")
dataset.set_fea_eval(1000000, True)
"""
if fea_eval:
self.dataset.set_fea_eval(fea_eval, record_candidate_size)
self.fea_eval = fea_eval
def slots_shuffle(self, slots):
"""
Slots Shuffle
Slots Shuffle is a shuffle method in slots level, which is usually used
in sparse feature with large scale of instances. To compare the metric, i.e.
auc while doing slots shuffle on one or several slots with baseline to
evaluate the importance level of slots(features).
Args:
slots(list[string]): the set of slots(string) to do slots shuffle.
Examples:
import paddle.fluid as fluid
dataset = fluid.DatasetFactory().create_dataset("InMemoryDataset")
dataset.set_merge_by_lineid()
#suppose there is a slot 0
dataset.slots_shuffle(['0'])
"""
if self.fea_eval:
slots_set = set(slots)
self.dataset.slots_shuffle(slots_set)
def set_batch_size(self, batch_size):
"""
Set batch size. Will be effective during training
......
......@@ -15,7 +15,6 @@ import os
import sys
from optimizer_factory import *
from google.protobuf import text_format
import paddle.fluid as fluid
from paddle.fluid.framework import Program
......@@ -212,6 +211,45 @@ class PSLib(Fleet):
self._fleet_ptr.save_model(dirname, mode)
self._role_maker._barrier_worker()
def save_cache_model(self, executor, dirname, main_program=None, **kwargs):
"""
save sparse cache table,
when using fleet, it will save sparse cache table
Args:
dirname(str): save path. It can be hdfs/afs path or local path
main_program(Program): fluid program, default None
kwargs: use define property, current support following
mode(int): define for feature extension in the future,
currently no use, will pass a default value 0
Example:
.. code-block:: python
>>> fleet.save_cache_model(None, dirname="/you/path/to/model", mode = 0)
"""
mode = kwargs.get("mode", 0)
self._fleet_ptr.client_flush()
self._role_maker._barrier_worker()
cache_threshold = 0.0
if self._role_maker.is_first_worker():
cache_threshold = self._fleet_ptr.get_cache_threshold()
#check cache threshold right or not
self._role_maker._barrier_worker()
if self._role_maker.is_first_worker():
self._fleet_ptr.cache_shuffle(0, dirname, mode, cache_threshold)
self._role_maker._barrier_worker()
feasign_num = -1
if self._role_maker.is_first_worker():
feasign_num = self._fleet_ptr.save_cache(0, dirname, mode)
self._role_maker._barrier_worker()
return feasign_num
def shrink_sparse_table(self):
"""
shrink cvm of all sparse embedding in pserver, the decay rate
......
......@@ -32,7 +32,7 @@ DESCRIPTOR = _descriptor.FileDescriptor(
package='paddle',
syntax='proto2',
serialized_pb=_b(
'\n\x08ps.proto\x12\x06paddle\"\x9e\x02\n\x0bPSParameter\x12\x14\n\x0cworker_class\x18\x01 \x01(\t\x12\x14\n\x0cserver_class\x18\x02 \x01(\t\x12\x16\n\x0einstance_class\x18\x03 \x01(\t\x12-\n\x0cworker_param\x18\x65 \x01(\x0b\x32\x17.paddle.WorkerParameter\x12-\n\x0cserver_param\x18\x66 \x01(\x0b\x32\x17.paddle.ServerParameter\x12\x38\n\rtrainer_param\x18\xad\x02 \x01(\x0b\x32 .paddle.DownpourTrainerParameter\x12\x33\n\x0f\x66s_client_param\x18\xf5\x03 \x01(\x0b\x32\x19.paddle.FsClientParameter\"Q\n\x0fWorkerParameter\x12>\n\x15\x64ownpour_worker_param\x18\x01 \x01(\x0b\x32\x1f.paddle.DownpourWorkerParameter\"Q\n\x0fServerParameter\x12>\n\x15\x64ownpour_server_param\x18\x01 \x01(\x0b\x32\x1f.paddle.DownpourServerParameter\"O\n\x17\x44ownpourWorkerParameter\x12\x34\n\x14\x64ownpour_table_param\x18\x01 \x03(\x0b\x32\x16.paddle.TableParameter\"\xfd\x01\n\x18\x44ownpourTrainerParameter\x12\x30\n\x0b\x64\x65nse_table\x18\x01 \x03(\x0b\x32\x1b.paddle.DenseTableParameter\x12\x32\n\x0csparse_table\x18\x02 \x03(\x0b\x32\x1c.paddle.SparseTableParameter\x12\x1d\n\x15push_sparse_per_batch\x18\x03 \x01(\x05\x12\x1c\n\x14push_dense_per_batch\x18\x04 \x01(\x05\x12\x0f\n\x07skip_op\x18\x05 \x03(\t\x12-\n\x0eprogram_config\x18\x06 \x03(\x0b\x32\x15.paddle.ProgramConfig\"\x99\x01\n\rProgramConfig\x12\x12\n\nprogram_id\x18\x01 \x02(\t\x12\x1c\n\x14push_sparse_table_id\x18\x02 \x03(\x05\x12\x1b\n\x13push_dense_table_id\x18\x03 \x03(\x05\x12\x1c\n\x14pull_sparse_table_id\x18\x04 \x03(\x05\x12\x1b\n\x13pull_dense_table_id\x18\x05 \x03(\x05\"{\n\x13\x44\x65nseTableParameter\x12\x10\n\x08table_id\x18\x01 \x01(\x05\x12\x1b\n\x13\x64\x65nse_variable_name\x18\x02 \x03(\t\x12$\n\x1c\x64\x65nse_gradient_variable_name\x18\x03 \x03(\t\x12\x0f\n\x07\x66\x65\x61_dim\x18\x04 \x01(\x05\"z\n\x14SparseTableParameter\x12\x10\n\x08table_id\x18\x01 \x01(\x05\x12\x13\n\x0b\x66\x65\x61ture_dim\x18\x02 \x01(\x05\x12\x10\n\x08slot_key\x18\x03 \x03(\t\x12\x12\n\nslot_value\x18\x04 \x03(\t\x12\x15\n\rslot_gradient\x18\x05 \x03(\t\"\x86\x01\n\x17\x44ownpourServerParameter\x12\x34\n\x14\x64ownpour_table_param\x18\x01 \x03(\x0b\x32\x16.paddle.TableParameter\x12\x35\n\rservice_param\x18\x02 \x01(\x0b\x32\x1e.paddle.ServerServiceParameter\"\xd7\x01\n\x16ServerServiceParameter\x12*\n\x0cserver_class\x18\x01 \x01(\t:\x14\x44ownpourBrpcPsServer\x12*\n\x0c\x63lient_class\x18\x02 \x01(\t:\x14\x44ownpourBrpcPsClient\x12(\n\rservice_class\x18\x03 \x01(\t:\x11\x44ownpourPsService\x12\x1c\n\x11start_server_port\x18\x04 \x01(\r:\x01\x30\x12\x1d\n\x11server_thread_num\x18\x05 \x01(\r:\x02\x31\x32\"\xc4\x01\n\x0eTableParameter\x12\x10\n\x08table_id\x18\x01 \x01(\x04\x12\x13\n\x0btable_class\x18\x02 \x01(\t\x12\x17\n\tshard_num\x18\x03 \x01(\x04:\x04\x31\x30\x30\x30\x12\x30\n\x08\x61\x63\x63\x65ssor\x18\x04 \x01(\x0b\x32\x1e.paddle.TableAccessorParameter\x12\x1f\n\x04type\x18\x05 \x01(\x0e\x32\x11.paddle.TableType\x12\x1f\n\x10\x63ompress_in_save\x18\x06 \x01(\x08:\x05\x66\x61lse\"\xf1\x02\n\x16TableAccessorParameter\x12\x16\n\x0e\x61\x63\x63\x65ssor_class\x18\x01 \x01(\t\x12\x38\n\x10sparse_sgd_param\x18\x02 \x01(\x0b\x32\x1e.paddle.SparseSGDRuleParameter\x12\x36\n\x0f\x64\x65nse_sgd_param\x18\x03 \x01(\x0b\x32\x1d.paddle.DenseSGDRuleParameter\x12\x0f\n\x07\x66\x65\x61_dim\x18\x04 \x01(\r\x12\x12\n\nembedx_dim\x18\x05 \x01(\r\x12\x18\n\x10\x65mbedx_threshold\x18\x06 \x01(\r\x12G\n\x17\x64ownpour_accessor_param\x18\x07 \x01(\x0b\x32&.paddle.DownpourTableAccessorParameter\x12\x45\n\x19table_accessor_save_param\x18\x08 \x03(\x0b\x32\".paddle.TableAccessorSaveParameter\"\xf0\x01\n\x1e\x44ownpourTableAccessorParameter\x12\x14\n\x0cnonclk_coeff\x18\x01 \x01(\x02\x12\x13\n\x0b\x63lick_coeff\x18\x02 \x01(\x02\x12\x16\n\x0e\x62\x61se_threshold\x18\x03 \x01(\x02\x12\x17\n\x0f\x64\x65lta_threshold\x18\x04 \x01(\x02\x12\x17\n\x0f\x64\x65lta_keep_days\x18\x05 \x01(\x02\x12\x1d\n\x15show_click_decay_rate\x18\x06 \x01(\x02\x12\x18\n\x10\x64\x65lete_threshold\x18\x07 \x01(\x02\x12 \n\x18\x64\x65lete_after_unseen_days\x18\x08 \x01(\x02\"S\n\x1aTableAccessorSaveParameter\x12\r\n\x05param\x18\x01 \x01(\r\x12\x11\n\tconverter\x18\x02 \x01(\t\x12\x13\n\x0b\x64\x65\x63onverter\x18\x03 \x01(\t\"e\n\x10PsRequestMessage\x12\x0e\n\x06\x63md_id\x18\x01 \x02(\r\x12\x10\n\x08table_id\x18\x02 \x01(\r\x12\x0e\n\x06params\x18\x03 \x03(\x0c\x12\x11\n\tclient_id\x18\x04 \x01(\x05\x12\x0c\n\x04\x64\x61ta\x18\x05 \x01(\x0c\"w\n\x16SparseSGDRuleParameter\x12\x15\n\rlearning_rate\x18\x01 \x01(\x01\x12\x15\n\rinitial_g2sum\x18\x02 \x01(\x01\x12\x18\n\rinitial_range\x18\x03 \x01(\x01:\x01\x30\x12\x15\n\rweight_bounds\x18\x04 \x03(\x02\"\xe1\x01\n\x15\x44\x65nseSGDRuleParameter\x12\x0c\n\x04name\x18\x01 \x01(\t\x12&\n\x04\x61\x64\x61m\x18\x02 \x01(\x0b\x32\x18.paddle.AdamSGDParameter\x12(\n\x05naive\x18\x03 \x01(\x0b\x32\x19.paddle.NaiveSGDParameter\x12,\n\x07summary\x18\x04 \x01(\x0b\x32\x1b.paddle.SummarySGDParameter\x12:\n\x0emoving_average\x18\x05 \x01(\x0b\x32\".paddle.MovingAverageRuleParameter\"\x86\x01\n\x10\x41\x64\x61mSGDParameter\x12\x15\n\rlearning_rate\x18\x01 \x01(\x01\x12\x16\n\x0e\x61vg_decay_rate\x18\x02 \x01(\x01\x12\x16\n\x0e\x61\x64\x61_decay_rate\x18\x03 \x01(\x01\x12\x13\n\x0b\x61\x64\x61_epsilon\x18\x04 \x01(\x01\x12\x16\n\x0emom_decay_rate\x18\x05 \x01(\x01\"B\n\x11NaiveSGDParameter\x12\x15\n\rlearning_rate\x18\x01 \x01(\x01\x12\x16\n\x0e\x61vg_decay_rate\x18\x02 \x01(\x01\";\n\x13SummarySGDParameter\x12$\n\x12summary_decay_rate\x18\x01 \x01(\x01:\x08\x30.999999\".\n\x1aMovingAverageRuleParameter\x12\x10\n\x08momentum\x18\x01 \x01(\x01\"I\n\x11PsResponseMessage\x12\x13\n\x08\x65rr_code\x18\x01 \x02(\x05:\x01\x30\x12\x11\n\x07\x65rr_msg\x18\x02 \x02(\t:\x00\x12\x0c\n\x04\x64\x61ta\x18\x03 \x01(\x0c\"\xd5\x01\n\x11\x46sClientParameter\x12:\n\x07\x66s_type\x18\x01 \x01(\x0e\x32#.paddle.FsClientParameter.FsApiType:\x04HDFS\x12\x0b\n\x03uri\x18\x02 \x01(\t\x12\x0c\n\x04user\x18\x03 \x01(\t\x12\x0e\n\x06passwd\x18\x04 \x01(\t\x12\x13\n\x0b\x62uffer_size\x18\x05 \x01(\x05\x12\x12\n\nhadoop_bin\x18\x33 \x01(\t\x12\x10\n\x08\x61\x66s_conf\x18\x65 \x01(\t\"\x1e\n\tFsApiType\x12\x08\n\x04HDFS\x10\x00\x12\x07\n\x03\x41\x46S\x10\x01*4\n\tTableType\x12\x13\n\x0fPS_SPARSE_TABLE\x10\x00\x12\x12\n\x0ePS_DENSE_TABLE\x10\x01*\xbd\x02\n\x07PsCmdID\x12\x17\n\x13PS_PULL_DENSE_TABLE\x10\x00\x12\x17\n\x13PS_PUSH_DENSE_TABLE\x10\x01\x12\x18\n\x14PS_PULL_SPARSE_TABLE\x10\x02\x12\x18\n\x14PS_PUSH_SPARSE_TABLE\x10\x03\x12\x13\n\x0fPS_SHRINK_TABLE\x10\x04\x12\x15\n\x11PS_SAVE_ONE_TABLE\x10\x05\x12\x15\n\x11PS_SAVE_ALL_TABLE\x10\x06\x12\x15\n\x11PS_LOAD_ONE_TABLE\x10\x07\x12\x15\n\x11PS_LOAD_ALL_TABLE\x10\x08\x12\x16\n\x12PS_CLEAR_ONE_TABLE\x10\t\x12\x16\n\x12PS_CLEAR_ALL_TABLE\x10\n\x12\x17\n\x13PS_PUSH_DENSE_PARAM\x10\x0b\x12\x12\n\x0ePS_STOP_SERVER\x10\x0c\x32K\n\tPsService\x12>\n\x07service\x12\x18.paddle.PsRequestMessage\x1a\x19.paddle.PsResponseMessageB\x03\x80\x01\x01'
'\n\x08ps.proto\x12\x06paddle\"\x9e\x02\n\x0bPSParameter\x12\x14\n\x0cworker_class\x18\x01 \x01(\t\x12\x14\n\x0cserver_class\x18\x02 \x01(\t\x12\x16\n\x0einstance_class\x18\x03 \x01(\t\x12-\n\x0cworker_param\x18\x65 \x01(\x0b\x32\x17.paddle.WorkerParameter\x12-\n\x0cserver_param\x18\x66 \x01(\x0b\x32\x17.paddle.ServerParameter\x12\x38\n\rtrainer_param\x18\xad\x02 \x01(\x0b\x32 .paddle.DownpourTrainerParameter\x12\x33\n\x0f\x66s_client_param\x18\xf5\x03 \x01(\x0b\x32\x19.paddle.FsClientParameter\"Q\n\x0fWorkerParameter\x12>\n\x15\x64ownpour_worker_param\x18\x01 \x01(\x0b\x32\x1f.paddle.DownpourWorkerParameter\"Q\n\x0fServerParameter\x12>\n\x15\x64ownpour_server_param\x18\x01 \x01(\x0b\x32\x1f.paddle.DownpourServerParameter\"O\n\x17\x44ownpourWorkerParameter\x12\x34\n\x14\x64ownpour_table_param\x18\x01 \x03(\x0b\x32\x16.paddle.TableParameter\"\xfd\x01\n\x18\x44ownpourTrainerParameter\x12\x30\n\x0b\x64\x65nse_table\x18\x01 \x03(\x0b\x32\x1b.paddle.DenseTableParameter\x12\x32\n\x0csparse_table\x18\x02 \x03(\x0b\x32\x1c.paddle.SparseTableParameter\x12\x1d\n\x15push_sparse_per_batch\x18\x03 \x01(\x05\x12\x1c\n\x14push_dense_per_batch\x18\x04 \x01(\x05\x12\x0f\n\x07skip_op\x18\x05 \x03(\t\x12-\n\x0eprogram_config\x18\x06 \x03(\x0b\x32\x15.paddle.ProgramConfig\"\x99\x01\n\rProgramConfig\x12\x12\n\nprogram_id\x18\x01 \x02(\t\x12\x1c\n\x14push_sparse_table_id\x18\x02 \x03(\x05\x12\x1b\n\x13push_dense_table_id\x18\x03 \x03(\x05\x12\x1c\n\x14pull_sparse_table_id\x18\x04 \x03(\x05\x12\x1b\n\x13pull_dense_table_id\x18\x05 \x03(\x05\"{\n\x13\x44\x65nseTableParameter\x12\x10\n\x08table_id\x18\x01 \x01(\x05\x12\x1b\n\x13\x64\x65nse_variable_name\x18\x02 \x03(\t\x12$\n\x1c\x64\x65nse_gradient_variable_name\x18\x03 \x03(\t\x12\x0f\n\x07\x66\x65\x61_dim\x18\x04 \x01(\x05\"z\n\x14SparseTableParameter\x12\x10\n\x08table_id\x18\x01 \x01(\x05\x12\x13\n\x0b\x66\x65\x61ture_dim\x18\x02 \x01(\x05\x12\x10\n\x08slot_key\x18\x03 \x03(\t\x12\x12\n\nslot_value\x18\x04 \x03(\t\x12\x15\n\rslot_gradient\x18\x05 \x03(\t\"\x86\x01\n\x17\x44ownpourServerParameter\x12\x34\n\x14\x64ownpour_table_param\x18\x01 \x03(\x0b\x32\x16.paddle.TableParameter\x12\x35\n\rservice_param\x18\x02 \x01(\x0b\x32\x1e.paddle.ServerServiceParameter\"\xd7\x01\n\x16ServerServiceParameter\x12*\n\x0cserver_class\x18\x01 \x01(\t:\x14\x44ownpourBrpcPsServer\x12*\n\x0c\x63lient_class\x18\x02 \x01(\t:\x14\x44ownpourBrpcPsClient\x12(\n\rservice_class\x18\x03 \x01(\t:\x11\x44ownpourPsService\x12\x1c\n\x11start_server_port\x18\x04 \x01(\r:\x01\x30\x12\x1d\n\x11server_thread_num\x18\x05 \x01(\r:\x02\x31\x32\"\xc0\x02\n\x0eTableParameter\x12\x10\n\x08table_id\x18\x01 \x01(\x04\x12\x13\n\x0btable_class\x18\x02 \x01(\t\x12\x17\n\tshard_num\x18\x03 \x01(\x04:\x04\x31\x30\x30\x30\x12\x30\n\x08\x61\x63\x63\x65ssor\x18\x04 \x01(\x0b\x32\x1e.paddle.TableAccessorParameter\x12\x1f\n\x04type\x18\x05 \x01(\x0e\x32\x11.paddle.TableType\x12\x1f\n\x10\x63ompress_in_save\x18\x06 \x01(\x08:\x05\x66\x61lse\x12\'\n\x19\x65nable_sparse_table_cache\x18\x07 \x01(\x08:\x04true\x12(\n\x17sparse_table_cache_rate\x18\x08 \x01(\x01:\x07\x30.00055\x12\'\n\x1bsparse_table_cache_file_num\x18\t \x01(\r:\x02\x31\x36\"\xf1\x02\n\x16TableAccessorParameter\x12\x16\n\x0e\x61\x63\x63\x65ssor_class\x18\x01 \x01(\t\x12\x38\n\x10sparse_sgd_param\x18\x02 \x01(\x0b\x32\x1e.paddle.SparseSGDRuleParameter\x12\x36\n\x0f\x64\x65nse_sgd_param\x18\x03 \x01(\x0b\x32\x1d.paddle.DenseSGDRuleParameter\x12\x0f\n\x07\x66\x65\x61_dim\x18\x04 \x01(\r\x12\x12\n\nembedx_dim\x18\x05 \x01(\r\x12\x18\n\x10\x65mbedx_threshold\x18\x06 \x01(\r\x12G\n\x17\x64ownpour_accessor_param\x18\x07 \x01(\x0b\x32&.paddle.DownpourTableAccessorParameter\x12\x45\n\x19table_accessor_save_param\x18\x08 \x03(\x0b\x32\".paddle.TableAccessorSaveParameter\"\xf0\x01\n\x1e\x44ownpourTableAccessorParameter\x12\x14\n\x0cnonclk_coeff\x18\x01 \x01(\x02\x12\x13\n\x0b\x63lick_coeff\x18\x02 \x01(\x02\x12\x16\n\x0e\x62\x61se_threshold\x18\x03 \x01(\x02\x12\x17\n\x0f\x64\x65lta_threshold\x18\x04 \x01(\x02\x12\x17\n\x0f\x64\x65lta_keep_days\x18\x05 \x01(\x02\x12\x1d\n\x15show_click_decay_rate\x18\x06 \x01(\x02\x12\x18\n\x10\x64\x65lete_threshold\x18\x07 \x01(\x02\x12 \n\x18\x64\x65lete_after_unseen_days\x18\x08 \x01(\x02\"S\n\x1aTableAccessorSaveParameter\x12\r\n\x05param\x18\x01 \x01(\r\x12\x11\n\tconverter\x18\x02 \x01(\t\x12\x13\n\x0b\x64\x65\x63onverter\x18\x03 \x01(\t\"e\n\x10PsRequestMessage\x12\x0e\n\x06\x63md_id\x18\x01 \x02(\r\x12\x10\n\x08table_id\x18\x02 \x01(\r\x12\x0e\n\x06params\x18\x03 \x03(\x0c\x12\x11\n\tclient_id\x18\x04 \x01(\x05\x12\x0c\n\x04\x64\x61ta\x18\x05 \x01(\x0c\"w\n\x16SparseSGDRuleParameter\x12\x15\n\rlearning_rate\x18\x01 \x01(\x01\x12\x15\n\rinitial_g2sum\x18\x02 \x01(\x01\x12\x18\n\rinitial_range\x18\x03 \x01(\x01:\x01\x30\x12\x15\n\rweight_bounds\x18\x04 \x03(\x02\"\xe1\x01\n\x15\x44\x65nseSGDRuleParameter\x12\x0c\n\x04name\x18\x01 \x01(\t\x12&\n\x04\x61\x64\x61m\x18\x02 \x01(\x0b\x32\x18.paddle.AdamSGDParameter\x12(\n\x05naive\x18\x03 \x01(\x0b\x32\x19.paddle.NaiveSGDParameter\x12,\n\x07summary\x18\x04 \x01(\x0b\x32\x1b.paddle.SummarySGDParameter\x12:\n\x0emoving_average\x18\x05 \x01(\x0b\x32\".paddle.MovingAverageRuleParameter\"\x86\x01\n\x10\x41\x64\x61mSGDParameter\x12\x15\n\rlearning_rate\x18\x01 \x01(\x01\x12\x16\n\x0e\x61vg_decay_rate\x18\x02 \x01(\x01\x12\x16\n\x0e\x61\x64\x61_decay_rate\x18\x03 \x01(\x01\x12\x13\n\x0b\x61\x64\x61_epsilon\x18\x04 \x01(\x01\x12\x16\n\x0emom_decay_rate\x18\x05 \x01(\x01\"B\n\x11NaiveSGDParameter\x12\x15\n\rlearning_rate\x18\x01 \x01(\x01\x12\x16\n\x0e\x61vg_decay_rate\x18\x02 \x01(\x01\";\n\x13SummarySGDParameter\x12$\n\x12summary_decay_rate\x18\x01 \x01(\x01:\x08\x30.999999\".\n\x1aMovingAverageRuleParameter\x12\x10\n\x08momentum\x18\x01 \x01(\x01\"I\n\x11PsResponseMessage\x12\x13\n\x08\x65rr_code\x18\x01 \x02(\x05:\x01\x30\x12\x11\n\x07\x65rr_msg\x18\x02 \x02(\t:\x00\x12\x0c\n\x04\x64\x61ta\x18\x03 \x01(\x0c\"\xd5\x01\n\x11\x46sClientParameter\x12:\n\x07\x66s_type\x18\x01 \x01(\x0e\x32#.paddle.FsClientParameter.FsApiType:\x04HDFS\x12\x0b\n\x03uri\x18\x02 \x01(\t\x12\x0c\n\x04user\x18\x03 \x01(\t\x12\x0e\n\x06passwd\x18\x04 \x01(\t\x12\x13\n\x0b\x62uffer_size\x18\x05 \x01(\x05\x12\x12\n\nhadoop_bin\x18\x33 \x01(\t\x12\x10\n\x08\x61\x66s_conf\x18\x65 \x01(\t\"\x1e\n\tFsApiType\x12\x08\n\x04HDFS\x10\x00\x12\x07\n\x03\x41\x46S\x10\x01*4\n\tTableType\x12\x13\n\x0fPS_SPARSE_TABLE\x10\x00\x12\x12\n\x0ePS_DENSE_TABLE\x10\x01*\x9c\x03\n\x07PsCmdID\x12\x17\n\x13PS_PULL_DENSE_TABLE\x10\x00\x12\x17\n\x13PS_PUSH_DENSE_TABLE\x10\x01\x12\x18\n\x14PS_PULL_SPARSE_TABLE\x10\x02\x12\x18\n\x14PS_PUSH_SPARSE_TABLE\x10\x03\x12\x13\n\x0fPS_SHRINK_TABLE\x10\x04\x12\x15\n\x11PS_SAVE_ONE_TABLE\x10\x05\x12\x15\n\x11PS_SAVE_ALL_TABLE\x10\x06\x12\x15\n\x11PS_LOAD_ONE_TABLE\x10\x07\x12\x15\n\x11PS_LOAD_ALL_TABLE\x10\x08\x12\x16\n\x12PS_CLEAR_ONE_TABLE\x10\t\x12\x16\n\x12PS_CLEAR_ALL_TABLE\x10\n\x12\x17\n\x13PS_PUSH_DENSE_PARAM\x10\x0b\x12\x12\n\x0ePS_STOP_SERVER\x10\x0c\x12\x1b\n\x17PS_SAVE_ONE_CACHE_TABLE\x10\r\x12\x1a\n\x16PS_GET_CACHE_THRESHOLD\x10\x0e\x12\x14\n\x10PS_CACHE_SHUFFLE\x10\x0f\x12\x0e\n\nPS_S2S_MSG\x10\x65\x32K\n\tPsService\x12>\n\x07service\x12\x18.paddle.PsRequestMessage\x1a\x19.paddle.PsResponseMessageB\x03\x80\x01\x01'
))
_sym_db.RegisterFileDescriptor(DESCRIPTOR)
......@@ -49,8 +49,8 @@ _TABLETYPE = _descriptor.EnumDescriptor(
],
containing_type=None,
options=None,
serialized_start=3528,
serialized_end=3580, )
serialized_start=3652,
serialized_end=3704, )
_sym_db.RegisterEnumDescriptor(_TABLETYPE)
TableType = enum_type_wrapper.EnumTypeWrapper(_TABLETYPE)
......@@ -131,11 +131,31 @@ _PSCMDID = _descriptor.EnumDescriptor(
_descriptor.EnumValueDescriptor(
name='PS_STOP_SERVER', index=12, number=12, options=None,
type=None),
_descriptor.EnumValueDescriptor(
name='PS_SAVE_ONE_CACHE_TABLE',
index=13,
number=13,
options=None,
type=None),
_descriptor.EnumValueDescriptor(
name='PS_GET_CACHE_THRESHOLD',
index=14,
number=14,
options=None,
type=None),
_descriptor.EnumValueDescriptor(
name='PS_CACHE_SHUFFLE',
index=15,
number=15,
options=None,
type=None),
_descriptor.EnumValueDescriptor(
name='PS_S2S_MSG', index=16, number=101, options=None, type=None),
],
containing_type=None,
options=None,
serialized_start=3583,
serialized_end=3900, )
serialized_start=3707,
serialized_end=4119, )
_sym_db.RegisterEnumDescriptor(_PSCMDID)
PsCmdID = enum_type_wrapper.EnumTypeWrapper(_PSCMDID)
......@@ -154,6 +174,10 @@ PS_CLEAR_ONE_TABLE = 9
PS_CLEAR_ALL_TABLE = 10
PS_PUSH_DENSE_PARAM = 11
PS_STOP_SERVER = 12
PS_SAVE_ONE_CACHE_TABLE = 13
PS_GET_CACHE_THRESHOLD = 14
PS_CACHE_SHUFFLE = 15
PS_S2S_MSG = 101
_FSCLIENTPARAMETER_FSAPITYPE = _descriptor.EnumDescriptor(
name='FsApiType',
......@@ -168,8 +192,8 @@ _FSCLIENTPARAMETER_FSAPITYPE = _descriptor.EnumDescriptor(
],
containing_type=None,
options=None,
serialized_start=3496,
serialized_end=3526, )
serialized_start=3620,
serialized_end=3650, )
_sym_db.RegisterEnumDescriptor(_FSCLIENTPARAMETER_FSAPITYPE)
_PSPARAMETER = _descriptor.Descriptor(
......@@ -1057,6 +1081,54 @@ _TABLEPARAMETER = _descriptor.Descriptor(
is_extension=False,
extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='enable_sparse_table_cache',
full_name='paddle.TableParameter.enable_sparse_table_cache',
index=6,
number=7,
type=8,
cpp_type=7,
label=1,
has_default_value=True,
default_value=True,
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='sparse_table_cache_rate',
full_name='paddle.TableParameter.sparse_table_cache_rate',
index=7,
number=8,
type=1,
cpp_type=5,
label=1,
has_default_value=True,
default_value=float(0.00055),
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='sparse_table_cache_file_num',
full_name='paddle.TableParameter.sparse_table_cache_file_num',
index=8,
number=9,
type=13,
cpp_type=3,
label=1,
has_default_value=True,
default_value=16,
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
],
extensions=[],
nested_types=[],
......@@ -1067,7 +1139,7 @@ _TABLEPARAMETER = _descriptor.Descriptor(
extension_ranges=[],
oneofs=[],
serialized_start=1573,
serialized_end=1769, )
serialized_end=1893, )
_TABLEACCESSORPARAMETER = _descriptor.Descriptor(
name='TableAccessorParameter',
......@@ -1213,8 +1285,8 @@ _TABLEACCESSORPARAMETER = _descriptor.Descriptor(
syntax='proto2',
extension_ranges=[],
oneofs=[],
serialized_start=1772,
serialized_end=2141, )
serialized_start=1896,
serialized_end=2265, )
_DOWNPOURTABLEACCESSORPARAMETER = _descriptor.Descriptor(
name='DownpourTableAccessorParameter',
......@@ -1360,8 +1432,8 @@ _DOWNPOURTABLEACCESSORPARAMETER = _descriptor.Descriptor(
syntax='proto2',
extension_ranges=[],
oneofs=[],
serialized_start=2144,
serialized_end=2384, )
serialized_start=2268,
serialized_end=2508, )
_TABLEACCESSORSAVEPARAMETER = _descriptor.Descriptor(
name='TableAccessorSaveParameter',
......@@ -1427,8 +1499,8 @@ _TABLEACCESSORSAVEPARAMETER = _descriptor.Descriptor(
syntax='proto2',
extension_ranges=[],
oneofs=[],
serialized_start=2386,
serialized_end=2469, )
serialized_start=2510,
serialized_end=2593, )
_PSREQUESTMESSAGE = _descriptor.Descriptor(
name='PsRequestMessage',
......@@ -1526,8 +1598,8 @@ _PSREQUESTMESSAGE = _descriptor.Descriptor(
syntax='proto2',
extension_ranges=[],
oneofs=[],
serialized_start=2471,
serialized_end=2572, )
serialized_start=2595,
serialized_end=2696, )
_SPARSESGDRULEPARAMETER = _descriptor.Descriptor(
name='SparseSGDRuleParameter',
......@@ -1609,8 +1681,8 @@ _SPARSESGDRULEPARAMETER = _descriptor.Descriptor(
syntax='proto2',
extension_ranges=[],
oneofs=[],
serialized_start=2574,
serialized_end=2693, )
serialized_start=2698,
serialized_end=2817, )
_DENSESGDRULEPARAMETER = _descriptor.Descriptor(
name='DenseSGDRuleParameter',
......@@ -1708,8 +1780,8 @@ _DENSESGDRULEPARAMETER = _descriptor.Descriptor(
syntax='proto2',
extension_ranges=[],
oneofs=[],
serialized_start=2696,
serialized_end=2921, )
serialized_start=2820,
serialized_end=3045, )
_ADAMSGDPARAMETER = _descriptor.Descriptor(
name='AdamSGDParameter',
......@@ -1807,8 +1879,8 @@ _ADAMSGDPARAMETER = _descriptor.Descriptor(
syntax='proto2',
extension_ranges=[],
oneofs=[],
serialized_start=2924,
serialized_end=3058, )
serialized_start=3048,
serialized_end=3182, )
_NAIVESGDPARAMETER = _descriptor.Descriptor(
name='NaiveSGDParameter',
......@@ -1858,8 +1930,8 @@ _NAIVESGDPARAMETER = _descriptor.Descriptor(
syntax='proto2',
extension_ranges=[],
oneofs=[],
serialized_start=3060,
serialized_end=3126, )
serialized_start=3184,
serialized_end=3250, )
_SUMMARYSGDPARAMETER = _descriptor.Descriptor(
name='SummarySGDParameter',
......@@ -1893,8 +1965,8 @@ _SUMMARYSGDPARAMETER = _descriptor.Descriptor(
syntax='proto2',
extension_ranges=[],
oneofs=[],
serialized_start=3128,
serialized_end=3187, )
serialized_start=3252,
serialized_end=3311, )
_MOVINGAVERAGERULEPARAMETER = _descriptor.Descriptor(
name='MovingAverageRuleParameter',
......@@ -1928,8 +2000,8 @@ _MOVINGAVERAGERULEPARAMETER = _descriptor.Descriptor(
syntax='proto2',
extension_ranges=[],
oneofs=[],
serialized_start=3189,
serialized_end=3235, )
serialized_start=3313,
serialized_end=3359, )
_PSRESPONSEMESSAGE = _descriptor.Descriptor(
name='PsResponseMessage',
......@@ -1995,8 +2067,8 @@ _PSRESPONSEMESSAGE = _descriptor.Descriptor(
syntax='proto2',
extension_ranges=[],
oneofs=[],
serialized_start=3237,
serialized_end=3310, )
serialized_start=3361,
serialized_end=3434, )
_FSCLIENTPARAMETER = _descriptor.Descriptor(
name='FsClientParameter',
......@@ -2126,8 +2198,8 @@ _FSCLIENTPARAMETER = _descriptor.Descriptor(
syntax='proto2',
extension_ranges=[],
oneofs=[],
serialized_start=3313,
serialized_end=3526, )
serialized_start=3437,
serialized_end=3650, )
_PSPARAMETER.fields_by_name['worker_param'].message_type = _WORKERPARAMETER
_PSPARAMETER.fields_by_name['server_param'].message_type = _SERVERPARAMETER
......
......@@ -109,6 +109,8 @@ class TestDataset(unittest.TestCase):
dataset.set_pipe_command("cat")
dataset.set_use_var(slots_vars)
dataset.load_into_memory()
dataset.set_fea_eval(10000, True)
dataset.slots_shuffle(["slot1"])
dataset.local_shuffle()
exe = fluid.Executor(fluid.CPUPlace())
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册