未验证 提交 893ea7e0 编写于 作者: T Thunderbrook 提交者: GitHub

[cherry-pick] find lookup table in order & support dump param (#21347)

* support dump param of model into afs (#20302)

* support dump param to afs
test=develop

* code style
test=develop

* code style
test=develop

* dump param
test=develop

* dump param
test=develop

* dump param
test=develop

* dump param
test=develop

* find lookup table in order (#20932)

test=develop

* cherry-pick
test=develop

* solve pslib core in stop worker
test=develop

* print table stat info for pslib
test=develop
上级 5dbe9e59
......@@ -105,7 +105,10 @@ class PullDenseWorker {
// should incorporate different type of device
class DeviceWorker {
public:
DeviceWorker() { use_cvm_ = false; }
DeviceWorker() {
no_cvm_ = true;
use_cvm_ = false;
}
virtual ~DeviceWorker() {}
virtual void Initialize(const TrainerDesc& desc) = 0;
virtual void SetDeviceIndex(int tid) = 0;
......@@ -135,6 +138,7 @@ class DeviceWorker {
int64_t batch_num_;
FetchConfig fetch_config_;
bool use_cvm_;
bool no_cvm_;
};
class CPUWorkerBase : public DeviceWorker {
......@@ -203,6 +207,8 @@ class DownpourWorker : public HogwildWorker {
void CopyDenseVars();
private:
bool need_dump_param_;
std::vector<std::string> dump_param_;
bool need_to_push_dense_;
bool need_dump_field_;
bool dump_slot_;
......
......@@ -75,6 +75,8 @@ void DownpourWorker::Initialize(const TrainerDesc& desc) {
fleet_ptr_ = FleetWrapper::GetInstance();
fetch_config_ = desc.fetch_config();
use_cvm_ = desc.use_cvm();
// for sparse value accessor, embedding only
no_cvm_ = desc.no_cvm();
scale_datanorm_ = desc.scale_datanorm();
dump_slot_ = desc.dump_slot();
dump_fields_.resize(desc.dump_fields_size());
......@@ -82,6 +84,14 @@ void DownpourWorker::Initialize(const TrainerDesc& desc) {
dump_fields_[i] = desc.dump_fields(i);
}
adjust_ins_weight_config_ = desc.adjust_ins_weight_config();
need_dump_param_ = false;
dump_param_.resize(desc.dump_param_size());
for (int i = 0; i < desc.dump_param_size(); ++i) {
dump_param_[i] = desc.dump_param(i);
}
if (desc.dump_param_size() != 0) {
need_dump_param_ = true;
}
for (int i = 0; i < desc.check_nan_var_names_size(); ++i) {
check_nan_var_names_.push_back(desc.check_nan_var_names(i));
}
......@@ -186,7 +196,26 @@ bool CheckValidOutput(LoDTensor* tensor, int batch_size) {
return true;
}
void DownpourWorker::DumpParam() {
std::string os;
for (auto& param : dump_param_) {
os.clear();
os = param;
Variable* var = thread_scope_->FindVar(param);
if (var == nullptr) {
continue;
}
LoDTensor* tensor = var->GetMutable<LoDTensor>();
int64_t len = tensor->numel();
os += PrintLodTensor(tensor, 0, len);
writer_ << os;
}
}
void DownpourWorker::CollectLabelInfo(size_t table_idx) {
if (no_cvm_) {
return;
}
uint64_t table_id = static_cast<uint64_t>(
param_.program_config(0).pull_sparse_table_id(table_idx));
......@@ -288,7 +317,7 @@ void DownpourWorker::FillSparseValue(size_t table_idx) {
int nid_ins_index = 0;
for (int index = 0; index < len; ++index) {
if (use_cvm_) {
if (use_cvm_ || no_cvm_) {
if (ids[index] == 0u) {
memcpy(ptr + table.emb_dim() * index, init_value.data(),
sizeof(float) * table.emb_dim());
......@@ -657,7 +686,7 @@ void DownpourWorker::TrainFilesWithProfiler() {
*thread_scope_, tid, features_[tid], feature_labels_[tid],
sparse_key_names_[tid], sparse_grad_names_[tid], table.emb_dim(),
&feature_grads_[tid], &push_sparse_status_, cur_batch, use_cvm_,
dump_slot_, &sparse_push_keys_[tid]);
dump_slot_, &sparse_push_keys_[tid], no_cvm_);
timeline.Pause();
push_sparse_time += timeline.ElapsedSec();
total_time += timeline.ElapsedSec();
......@@ -882,7 +911,7 @@ void DownpourWorker::TrainFiles() {
*thread_scope_, tid, features_[tid], feature_labels_[tid],
sparse_key_names_[tid], sparse_grad_names_[tid], table.emb_dim(),
&feature_grads_[tid], &push_sparse_status_, cur_batch, use_cvm_,
dump_slot_, &sparse_push_keys_[tid]);
dump_slot_, &sparse_push_keys_[tid], no_cvm_);
}
}
......@@ -977,6 +1006,9 @@ void DownpourWorker::TrainFiles() {
}
writer_ << ars[i];
}
if (need_dump_param_ && thread_id_ == 0) {
DumpParam();
}
}
PrintFetchVars();
......
......@@ -91,6 +91,13 @@ void FleetWrapper::StopServer() {
#endif
}
void FleetWrapper::FinalizeWorker() {
#ifdef PADDLE_WITH_PSLIB
VLOG(3) << "Going to finalize worker";
pslib_ptr_->finalize_worker();
#endif
}
uint64_t FleetWrapper::RunServer() {
#ifdef PADDLE_WITH_PSLIB
VLOG(3) << "Going to run server";
......@@ -303,7 +310,7 @@ void FleetWrapper::PushSparseVarsWithLabelAsync(
std::vector<std::vector<float>>* push_values,
std::vector<::std::future<int32_t>>* push_sparse_status,
const int batch_size, const bool use_cvm, const bool dump_slot,
std::vector<uint64_t>* sparse_push_keys) {
std::vector<uint64_t>* sparse_push_keys, const bool no_cvm) {
#ifdef PADDLE_WITH_PSLIB
int offset = 2;
int slot_offset = 0;
......@@ -314,6 +321,10 @@ void FleetWrapper::PushSparseVarsWithLabelAsync(
offset = 0;
grad_dim = emb_dim - 2;
}
if (no_cvm) {
offset = 0;
grad_dim = emb_dim;
}
if (dump_slot) {
slot_offset = 1;
show_index = 1;
......@@ -370,12 +381,12 @@ void FleetWrapper::PushSparseVarsWithLabelAsync(
}
sparse_push_keys->push_back(ids[id_idx]);
CHECK(fea_idx < (*push_values).size());
CHECK(fea_idx < fea_labels.size());
if (use_cvm) {
if (use_cvm || no_cvm) {
memcpy((*push_values)[fea_idx].data() + offset + slot_offset, g,
sizeof(float) * emb_dim);
} else {
CHECK(fea_idx < fea_labels.size());
memcpy((*push_values)[fea_idx].data() + offset + slot_offset, g,
sizeof(float) * emb_dim);
(*push_values)[fea_idx][show_index] = 1.0f;
......@@ -549,6 +560,19 @@ void FleetWrapper::SaveModel(const std::string& path, const int mode) {
#endif
}
void FleetWrapper::PrintTableStat(const uint64_t table_id) {
#ifdef PADDLE_WITH_PSLIB
auto ret = pslib_ptr_->_worker_ptr->print_table_stat(table_id);
ret.wait();
int32_t err_code = ret.get();
if (err_code == -1) {
LOG(ERROR) << "print table stat failed";
}
#else
VLOG(0) << "FleetWrapper::PrintTableStat does nothing when no pslib";
#endif
}
double FleetWrapper::GetCacheThreshold(int table_id) {
#ifdef PADDLE_WITH_PSLIB
double cache_threshold = 0.0;
......
......@@ -124,7 +124,7 @@ class FleetWrapper {
std::vector<std::vector<float>>* push_values,
std::vector<::std::future<int32_t>>* push_sparse_status,
const int batch_size, const bool use_cvm, const bool dump_slot,
std::vector<uint64_t>* sparse_push_keys);
std::vector<uint64_t>* sparse_push_keys, const bool no_cvm);
// Push sparse variables to server in Async mode
// Param<In>: scope, table_id, fea_keys, sparse_grad_names
......@@ -147,6 +147,8 @@ class FleetWrapper {
int index);
// stop server
void StopServer();
// finalize worker to make worker can be stop
void FinalizeWorker();
// run server
uint64_t RunServer();
// gather server ip
......@@ -165,6 +167,8 @@ class FleetWrapper {
std::string model_path, std::string model_proto_file,
std::vector<std::string> table_var_list,
bool load_combine);
void PrintTableStat(const uint64_t table_id);
// mode = 0, load all feature
// mode = 1, laod delta feature, which means load diff
void LoadModel(const std::string& path, const int mode);
......
......@@ -105,7 +105,6 @@ class DistMultiTrainer : public MultiTrainer {
bool need_dump_field_;
std::string dump_fields_path_;
std::string dump_converter_;
std::vector<std::string> dump_fields_;
int mpi_rank_;
int mpi_size_;
int dump_file_num_;
......
......@@ -40,12 +40,14 @@ message TrainerDesc {
repeated string dump_fields = 13;
optional string dump_converter = 14;
repeated string dump_param = 15;
optional int32 mpi_size = 16 [ default = -1 ];
optional int32 dump_file_num = 17 [ default = 16 ];
repeated string check_nan_var_names = 18;
optional CopyTableConfig copy_table_config = 19;
// adjust ins weight
optional AdjustInsWeightConfig adjust_ins_weight_config = 20;
optional bool no_cvm = 21 [ default = false ];
// device worker parameters
optional HogwildWorkerParameter hogwild_param = 101;
......
......@@ -55,6 +55,7 @@ void BindFleetWrapper(py::module* m) {
.def("load_model", &framework::FleetWrapper::LoadModel)
.def("clear_model", &framework::FleetWrapper::ClearModel)
.def("stop_server", &framework::FleetWrapper::StopServer)
.def("finalize_worker", &framework::FleetWrapper::FinalizeWorker)
.def("gather_servers", &framework::FleetWrapper::GatherServers)
.def("gather_clients", &framework::FleetWrapper::GatherClients)
.def("get_clients_info", &framework::FleetWrapper::GetClientsInfo)
......@@ -62,6 +63,7 @@ void BindFleetWrapper(py::module* m) {
&framework::FleetWrapper::CreateClient2ClientConnection)
.def("shrink_sparse_table", &framework::FleetWrapper::ShrinkSparseTable)
.def("shrink_dense_table", &framework::FleetWrapper::ShrinkDenseTable)
.def("print_table_stat", &framework::FleetWrapper::PrintTableStat)
.def("client_flush", &framework::FleetWrapper::ClientFlush)
.def("load_from_paddle_model",
&framework::FleetWrapper::LoadFromPaddleModel)
......
......@@ -160,7 +160,8 @@ class DownpourSGD(DeviceWorker):
.sparse_table[i].slot_value)
sparse_table.sparse_grad_name.extend(worker.get_desc().sparse_table[
i].slot_gradient)
if opt_info["use_cvm"]:
if opt_info["use_cvm"] or "no_cvm" in opt_info and opt_info[
"no_cvm"] == True:
sparse_table.emb_dim = \
self._fleet_desc.server_param.downpour_server_param.downpour_table_param[
i].accessor.fea_dim
......
......@@ -182,6 +182,10 @@ class PSLib(Fleet):
destroyed when stop() is called.
"""
self._role_maker._barrier_worker()
# all worker should be finalize first
if self._role_maker.is_worker():
self._fleet_ptr.finalize_worker()
self._role_maker._barrier_worker()
if self._role_maker.is_first_worker():
self._fleet_ptr.stop_server()
self._role_maker._barrier_worker()
......@@ -234,6 +238,25 @@ class PSLib(Fleet):
"""
self._fleet_ptr.save_model(dirname)
def print_table_stat(self, table_id):
"""
print stat info of table_id,
format: tableid, feasign size, mf size
Args:
table_id(int): the id of table
Example:
.. code-block:: python
fleet.print_table_stat(0)
"""
self._role_maker._barrier_worker()
if self._role_maker.is_first_worker():
self._fleet_ptr.print_table_stat(table_id)
self._role_maker._barrier_worker()
def save_persistables(self, executor, dirname, main_program=None, **kwargs):
"""
save presistable parameters,
......
......@@ -80,7 +80,8 @@ class DownpourServer(Server):
'sparse_click_coeff', 'sparse_base_threshold', 'sparse_delta_threshold', 'sparse_delta_keep_days', \
'sparse_delete_after_unseen_days', 'sparse_show_click_decay_rate', 'sparse_delete_threshold', \
'sparse_converter', 'sparse_deconverter', 'sparse_enable_cache', 'sparse_cache_rate', \
'sparse_cache_file_num']
'sparse_cache_file_num', 'sparse_beta1_decay_rate', 'sparse_beta2_decay_rate', \
'sparse_ada_epsilon', 'sparse_optimizer']
for key in strategy:
if key not in support_sparse_key_list:
......@@ -108,9 +109,13 @@ class DownpourServer(Server):
table.compress_in_save = strategy.get('sparse_compress_in_save',
True)
table.shard_num = strategy.get('sparse_shard_num', 1000)
# DownpourFeatureValueAccessor: for ctr task, has cvm, embedding and sgd info
# DownpourCtrAccessor : for ctr task, has cvm, slot, embedding and sgd info
# DownpourSparseValueAccessor : for general task, has embedding and sgd info
support_accessor_class = [
'DownpourFeatureValueAccessor', 'DownpourCtrAccessor'
'DownpourFeatureValueAccessor', 'DownpourCtrAccessor',
'DownpourSparseValueAccessor'
]
if strategy.get('sparse_accessor_class') is not None:
accessor_class = strategy.get('sparse_accessor_class')
......@@ -169,6 +174,69 @@ class DownpourServer(Server):
table1.converter = converter
table1.deconverter = deconverter
table2 = table.accessor.table_accessor_save_param.add()
table2.param = 2
table2.converter = converter
table2.deconverter = deconverter
elif accessor_class == 'DownpourSparseValueAccessor':
optimizer_name = strategy.get("sparse_optimizer", "adam")
table.accessor.sparse_commonsgd_param.name = optimizer_name
table.accessor.embedx_dim = strategy.get('sparse_embedx_dim', 8)
table.accessor.fea_dim = int(table.accessor.embedx_dim)
if optimizer_name == "naive":
table.accessor.sparse_commonsgd_param.naive.learning_rate = \
strategy.get('sparse_learning_rate', 0.05)
table.accessor.sparse_commonsgd_param.naive.initial_range = \
strategy.get('sparse_initial_range', 1e-4)
if strategy.get('sparse_weight_bounds') is None:
table.accessor.sparse_commonsgd_param.naive.weight_bounds.extend(
[-10, 10])
else:
table.accessor.sparse_commonsgd_param.naive.weight_bounds.extend(
strategy.get('sparse_weight_bounds'))
elif optimizer_name == "adagrad":
table.accessor.sparse_commonsgd_param.adagrad.learning_rate = \
strategy.get('sparse_learning_rate', 0.05)
table.accessor.sparse_commonsgd_param.adagrad.initial_range = \
strategy.get('sparse_initial_range', 1e-4)
table.accessor.sparse_commonsgd_param.adagrad.initial_g2sum = strategy.get(
'sparse_initial_g2sum', 3)
if strategy.get('sparse_weight_bounds') is None:
table.accessor.sparse_commonsgd_param.adagrad.weight_bounds.extend(
[-10, 10])
else:
table.accessor.sparse_commonsgd_param.adagrad.weight_bounds.extend(
strategy.get('sparse_weight_bounds'))
elif optimizer_name == "adam":
table.accessor.sparse_commonsgd_param.adam.learning_rate = \
strategy.get('sparse_learning_rate', 0.001)
table.accessor.sparse_commonsgd_param.adam.initial_range = \
strategy.get('sparse_initial_range', 1e-4)
table.accessor.sparse_commonsgd_param.adam.beta1_decay_rate = strategy.get(
'sparse_beta1_decay_rate', 0.9)
table.accessor.sparse_commonsgd_param.adam.beta2_decay_rate = strategy.get(
'sparse_beta2_decay_rate', 0.999)
table.accessor.sparse_commonsgd_param.adam.ada_epsilon = strategy.get(
'sparse_ada_epsilon', 1e-8)
if strategy.get('sparse_weight_bounds') is None:
table.accessor.sparse_commonsgd_param.adam.weight_bounds.extend(
[-10, 10])
else:
table.accessor.sparse_commonsgd_param.adam.weight_bounds.extend(
strategy.get('sparse_weight_bounds'))
converter = strategy.get(
'sparse_converter',
"(scripts/xbox_compressor_mf.py | bin/xbox_pb_converter)")
deconverter = strategy.get(
'sparse_deconverter',
"(bin/xbox_pb_deconverter | scripts/xbox_decompressor_mf.awk)"
)
table1 = table.accessor.table_accessor_save_param.add()
table1.param = 1
table1.converter = converter
table1.deconverter = deconverter
table2 = table.accessor.table_accessor_save_param.add()
table2.param = 2
table2.converter = converter
......
......@@ -130,13 +130,22 @@ class DistributedAdam(DistributedOptimizerImplBase):
find multi-sparse-table
"""
table_names = set()
cnt = 0
tmp_list = []
ret_list = []
for loss in losses:
for op in loss.block.program.global_block().ops:
if op.type == "lookup_table":
if op.attr('is_distributed') is True:
table_name = op.input("W")[0]
table_names.add(table_name)
return list(table_names)
if table_name not in table_names:
table_names.add(table_name)
tmp_list.append([table_name, cnt])
cnt += 1
tmp_list.sort(key=lambda k: k[1])
for x in tmp_list:
ret_list.append(x[0])
return ret_list
def _minimize(self,
losses,
......@@ -366,6 +375,7 @@ class DistributedAdam(DistributedOptimizerImplBase):
opt_info["fleet_desc"] = ps_param
opt_info["worker_skipped_ops"] = worker_skipped_ops
opt_info["use_cvm"] = strategy.get("use_cvm", False)
opt_info["no_cvm"] = strategy.get("no_cvm", False)
opt_info["stat_var_names"] = strategy.get("stat_var_names", [])
opt_info["scale_datanorm"] = strategy.get("scale_datanorm", -1)
opt_info["check_nan_var_names"] = strategy.get("check_nan_var_names",
......@@ -375,6 +385,7 @@ class DistributedAdam(DistributedOptimizerImplBase):
opt_info["dump_fields"] = strategy.get("dump_fields", [])
opt_info["dump_file_num"] = strategy.get("dump_file_num", 16)
opt_info["dump_fields_path"] = strategy.get("dump_fields_path", "")
opt_info["dump_param"] = strategy.get("dump_param", [])
if server._server.downpour_server_param.downpour_table_param[
0].accessor.accessor_class == "DownpourCtrAccessor":
opt_info["dump_slot"] = True
......
......@@ -32,7 +32,7 @@ DESCRIPTOR = _descriptor.FileDescriptor(
package='paddle',
syntax='proto2',
serialized_pb=_b(
'\n\x08ps.proto\x12\x06paddle\"\x9e\x02\n\x0bPSParameter\x12\x14\n\x0cworker_class\x18\x01 \x01(\t\x12\x14\n\x0cserver_class\x18\x02 \x01(\t\x12\x16\n\x0einstance_class\x18\x03 \x01(\t\x12-\n\x0cworker_param\x18\x65 \x01(\x0b\x32\x17.paddle.WorkerParameter\x12-\n\x0cserver_param\x18\x66 \x01(\x0b\x32\x17.paddle.ServerParameter\x12\x38\n\rtrainer_param\x18\xad\x02 \x03(\x0b\x32 .paddle.DownpourTrainerParameter\x12\x33\n\x0f\x66s_client_param\x18\xf5\x03 \x01(\x0b\x32\x19.paddle.FsClientParameter\"Q\n\x0fWorkerParameter\x12>\n\x15\x64ownpour_worker_param\x18\x01 \x01(\x0b\x32\x1f.paddle.DownpourWorkerParameter\"Q\n\x0fServerParameter\x12>\n\x15\x64ownpour_server_param\x18\x01 \x01(\x0b\x32\x1f.paddle.DownpourServerParameter\"O\n\x17\x44ownpourWorkerParameter\x12\x34\n\x14\x64ownpour_table_param\x18\x01 \x03(\x0b\x32\x16.paddle.TableParameter\"\xfd\x01\n\x18\x44ownpourTrainerParameter\x12\x30\n\x0b\x64\x65nse_table\x18\x01 \x03(\x0b\x32\x1b.paddle.DenseTableParameter\x12\x32\n\x0csparse_table\x18\x02 \x03(\x0b\x32\x1c.paddle.SparseTableParameter\x12\x1d\n\x15push_sparse_per_batch\x18\x03 \x01(\x05\x12\x1c\n\x14push_dense_per_batch\x18\x04 \x01(\x05\x12\x0f\n\x07skip_op\x18\x05 \x03(\t\x12-\n\x0eprogram_config\x18\x06 \x03(\x0b\x32\x15.paddle.ProgramConfig\"\x99\x01\n\rProgramConfig\x12\x12\n\nprogram_id\x18\x01 \x02(\t\x12\x1c\n\x14push_sparse_table_id\x18\x02 \x03(\x05\x12\x1b\n\x13push_dense_table_id\x18\x03 \x03(\x05\x12\x1c\n\x14pull_sparse_table_id\x18\x04 \x03(\x05\x12\x1b\n\x13pull_dense_table_id\x18\x05 \x03(\x05\"{\n\x13\x44\x65nseTableParameter\x12\x10\n\x08table_id\x18\x01 \x01(\x05\x12\x1b\n\x13\x64\x65nse_variable_name\x18\x02 \x03(\t\x12$\n\x1c\x64\x65nse_gradient_variable_name\x18\x03 \x03(\t\x12\x0f\n\x07\x66\x65\x61_dim\x18\x04 \x01(\x05\"z\n\x14SparseTableParameter\x12\x10\n\x08table_id\x18\x01 \x01(\x05\x12\x13\n\x0b\x66\x65\x61ture_dim\x18\x02 \x01(\x05\x12\x10\n\x08slot_key\x18\x03 \x03(\t\x12\x12\n\nslot_value\x18\x04 \x03(\t\x12\x15\n\rslot_gradient\x18\x05 \x03(\t\"\x86\x01\n\x17\x44ownpourServerParameter\x12\x34\n\x14\x64ownpour_table_param\x18\x01 \x03(\x0b\x32\x16.paddle.TableParameter\x12\x35\n\rservice_param\x18\x02 \x01(\x0b\x32\x1e.paddle.ServerServiceParameter\"\xd7\x01\n\x16ServerServiceParameter\x12*\n\x0cserver_class\x18\x01 \x01(\t:\x14\x44ownpourBrpcPsServer\x12*\n\x0c\x63lient_class\x18\x02 \x01(\t:\x14\x44ownpourBrpcPsClient\x12(\n\rservice_class\x18\x03 \x01(\t:\x11\x44ownpourPsService\x12\x1c\n\x11start_server_port\x18\x04 \x01(\r:\x01\x30\x12\x1d\n\x11server_thread_num\x18\x05 \x01(\r:\x02\x31\x32\"\xc0\x02\n\x0eTableParameter\x12\x10\n\x08table_id\x18\x01 \x01(\x04\x12\x13\n\x0btable_class\x18\x02 \x01(\t\x12\x17\n\tshard_num\x18\x03 \x01(\x04:\x04\x31\x30\x30\x30\x12\x30\n\x08\x61\x63\x63\x65ssor\x18\x04 \x01(\x0b\x32\x1e.paddle.TableAccessorParameter\x12\x1f\n\x04type\x18\x05 \x01(\x0e\x32\x11.paddle.TableType\x12\x1f\n\x10\x63ompress_in_save\x18\x06 \x01(\x08:\x05\x66\x61lse\x12\'\n\x19\x65nable_sparse_table_cache\x18\x07 \x01(\x08:\x04true\x12(\n\x17sparse_table_cache_rate\x18\x08 \x01(\x01:\x07\x30.00055\x12\'\n\x1bsparse_table_cache_file_num\x18\t \x01(\r:\x02\x31\x36\"\xfc\x02\n\x16TableAccessorParameter\x12\x16\n\x0e\x61\x63\x63\x65ssor_class\x18\x01 \x01(\t\x12\x38\n\x10sparse_sgd_param\x18\x02 \x01(\x0b\x32\x1e.paddle.SparseSGDRuleParameter\x12\x36\n\x0f\x64\x65nse_sgd_param\x18\x03 \x01(\x0b\x32\x1d.paddle.DenseSGDRuleParameter\x12\x13\n\x07\x66\x65\x61_dim\x18\x04 \x01(\r:\x02\x31\x31\x12\x15\n\nembedx_dim\x18\x05 \x01(\r:\x01\x38\x12\x1c\n\x10\x65mbedx_threshold\x18\x06 \x01(\r:\x02\x31\x30\x12G\n\x17\x64ownpour_accessor_param\x18\x07 \x01(\x0b\x32&.paddle.DownpourTableAccessorParameter\x12\x45\n\x19table_accessor_save_param\x18\x08 \x03(\x0b\x32\".paddle.TableAccessorSaveParameter\"\x96\x02\n\x1e\x44ownpourTableAccessorParameter\x12\x19\n\x0cnonclk_coeff\x18\x01 \x01(\x02:\x03\x30.1\x12\x16\n\x0b\x63lick_coeff\x18\x02 \x01(\x02:\x01\x31\x12\x1b\n\x0e\x62\x61se_threshold\x18\x03 \x01(\x02:\x03\x31.5\x12\x1d\n\x0f\x64\x65lta_threshold\x18\x04 \x01(\x02:\x04\x30.25\x12\x1b\n\x0f\x64\x65lta_keep_days\x18\x05 \x01(\x02:\x02\x31\x36\x12#\n\x15show_click_decay_rate\x18\x06 \x01(\x02:\x04\x30.98\x12\x1d\n\x10\x64\x65lete_threshold\x18\x07 \x01(\x02:\x03\x30.8\x12$\n\x18\x64\x65lete_after_unseen_days\x18\x08 \x01(\x02:\x02\x33\x30\"S\n\x1aTableAccessorSaveParameter\x12\r\n\x05param\x18\x01 \x01(\r\x12\x11\n\tconverter\x18\x02 \x01(\t\x12\x13\n\x0b\x64\x65\x63onverter\x18\x03 \x01(\t\"e\n\x10PsRequestMessage\x12\x0e\n\x06\x63md_id\x18\x01 \x02(\r\x12\x10\n\x08table_id\x18\x02 \x01(\r\x12\x0e\n\x06params\x18\x03 \x03(\x0c\x12\x11\n\tclient_id\x18\x04 \x01(\x05\x12\x0c\n\x04\x64\x61ta\x18\x05 \x01(\x0c\"\x85\x01\n\x16SparseSGDRuleParameter\x12\x1b\n\rlearning_rate\x18\x01 \x01(\x01:\x04\x30.05\x12\x18\n\rinitial_g2sum\x18\x02 \x01(\x01:\x01\x33\x12\x1d\n\rinitial_range\x18\x03 \x01(\x01:\x06\x30.0001\x12\x15\n\rweight_bounds\x18\x04 \x03(\x02\"\xe1\x01\n\x15\x44\x65nseSGDRuleParameter\x12\x0c\n\x04name\x18\x01 \x01(\t\x12&\n\x04\x61\x64\x61m\x18\x02 \x01(\x0b\x32\x18.paddle.AdamSGDParameter\x12(\n\x05naive\x18\x03 \x01(\x0b\x32\x19.paddle.NaiveSGDParameter\x12,\n\x07summary\x18\x04 \x01(\x0b\x32\x1b.paddle.SummarySGDParameter\x12:\n\x0emoving_average\x18\x05 \x01(\x0b\x32\".paddle.MovingAverageRuleParameter\"\xac\x01\n\x10\x41\x64\x61mSGDParameter\x12\x1c\n\rlearning_rate\x18\x01 \x01(\x01:\x05\x35\x65-06\x12 \n\x0e\x61vg_decay_rate\x18\x02 \x01(\x01:\x08\x30.999993\x12\x1e\n\x0e\x61\x64\x61_decay_rate\x18\x03 \x01(\x01:\x06\x30.9999\x12\x1a\n\x0b\x61\x64\x61_epsilon\x18\x04 \x01(\x01:\x05\x31\x65-08\x12\x1c\n\x0emom_decay_rate\x18\x05 \x01(\x01:\x04\x30.99\"J\n\x11NaiveSGDParameter\x12\x1d\n\rlearning_rate\x18\x01 \x01(\x01:\x06\x30.0002\x12\x16\n\x0e\x61vg_decay_rate\x18\x02 \x01(\x01\";\n\x13SummarySGDParameter\x12$\n\x12summary_decay_rate\x18\x01 \x01(\x01:\x08\x30.999999\".\n\x1aMovingAverageRuleParameter\x12\x10\n\x08momentum\x18\x01 \x01(\x01\"I\n\x11PsResponseMessage\x12\x13\n\x08\x65rr_code\x18\x01 \x02(\x05:\x01\x30\x12\x11\n\x07\x65rr_msg\x18\x02 \x02(\t:\x00\x12\x0c\n\x04\x64\x61ta\x18\x03 \x01(\x0c\"\xd5\x01\n\x11\x46sClientParameter\x12:\n\x07\x66s_type\x18\x01 \x01(\x0e\x32#.paddle.FsClientParameter.FsApiType:\x04HDFS\x12\x0b\n\x03uri\x18\x02 \x01(\t\x12\x0c\n\x04user\x18\x03 \x01(\t\x12\x0e\n\x06passwd\x18\x04 \x01(\t\x12\x13\n\x0b\x62uffer_size\x18\x05 \x01(\x05\x12\x12\n\nhadoop_bin\x18\x33 \x01(\t\x12\x10\n\x08\x61\x66s_conf\x18\x65 \x01(\t\"\x1e\n\tFsApiType\x12\x08\n\x04HDFS\x10\x00\x12\x07\n\x03\x41\x46S\x10\x01*4\n\tTableType\x12\x13\n\x0fPS_SPARSE_TABLE\x10\x00\x12\x12\n\x0ePS_DENSE_TABLE\x10\x01*\x9c\x03\n\x07PsCmdID\x12\x17\n\x13PS_PULL_DENSE_TABLE\x10\x00\x12\x17\n\x13PS_PUSH_DENSE_TABLE\x10\x01\x12\x18\n\x14PS_PULL_SPARSE_TABLE\x10\x02\x12\x18\n\x14PS_PUSH_SPARSE_TABLE\x10\x03\x12\x13\n\x0fPS_SHRINK_TABLE\x10\x04\x12\x15\n\x11PS_SAVE_ONE_TABLE\x10\x05\x12\x15\n\x11PS_SAVE_ALL_TABLE\x10\x06\x12\x15\n\x11PS_LOAD_ONE_TABLE\x10\x07\x12\x15\n\x11PS_LOAD_ALL_TABLE\x10\x08\x12\x16\n\x12PS_CLEAR_ONE_TABLE\x10\t\x12\x16\n\x12PS_CLEAR_ALL_TABLE\x10\n\x12\x17\n\x13PS_PUSH_DENSE_PARAM\x10\x0b\x12\x12\n\x0ePS_STOP_SERVER\x10\x0c\x12\x1b\n\x17PS_SAVE_ONE_CACHE_TABLE\x10\r\x12\x1a\n\x16PS_GET_CACHE_THRESHOLD\x10\x0e\x12\x14\n\x10PS_CACHE_SHUFFLE\x10\x0f\x12\x0e\n\nPS_S2S_MSG\x10\x65\x32K\n\tPsService\x12>\n\x07service\x12\x18.paddle.PsRequestMessage\x1a\x19.paddle.PsResponseMessageB\x03\x80\x01\x01'
'\n\x08ps.proto\x12\x06paddle\"\x9e\x02\n\x0bPSParameter\x12\x14\n\x0cworker_class\x18\x01 \x01(\t\x12\x14\n\x0cserver_class\x18\x02 \x01(\t\x12\x16\n\x0einstance_class\x18\x03 \x01(\t\x12-\n\x0cworker_param\x18\x65 \x01(\x0b\x32\x17.paddle.WorkerParameter\x12-\n\x0cserver_param\x18\x66 \x01(\x0b\x32\x17.paddle.ServerParameter\x12\x38\n\rtrainer_param\x18\xad\x02 \x03(\x0b\x32 .paddle.DownpourTrainerParameter\x12\x33\n\x0f\x66s_client_param\x18\xf5\x03 \x01(\x0b\x32\x19.paddle.FsClientParameter\"Q\n\x0fWorkerParameter\x12>\n\x15\x64ownpour_worker_param\x18\x01 \x01(\x0b\x32\x1f.paddle.DownpourWorkerParameter\"Q\n\x0fServerParameter\x12>\n\x15\x64ownpour_server_param\x18\x01 \x01(\x0b\x32\x1f.paddle.DownpourServerParameter\"O\n\x17\x44ownpourWorkerParameter\x12\x34\n\x14\x64ownpour_table_param\x18\x01 \x03(\x0b\x32\x16.paddle.TableParameter\"\xfd\x01\n\x18\x44ownpourTrainerParameter\x12\x30\n\x0b\x64\x65nse_table\x18\x01 \x03(\x0b\x32\x1b.paddle.DenseTableParameter\x12\x32\n\x0csparse_table\x18\x02 \x03(\x0b\x32\x1c.paddle.SparseTableParameter\x12\x1d\n\x15push_sparse_per_batch\x18\x03 \x01(\x05\x12\x1c\n\x14push_dense_per_batch\x18\x04 \x01(\x05\x12\x0f\n\x07skip_op\x18\x05 \x03(\t\x12-\n\x0eprogram_config\x18\x06 \x03(\x0b\x32\x15.paddle.ProgramConfig\"\x99\x01\n\rProgramConfig\x12\x12\n\nprogram_id\x18\x01 \x02(\t\x12\x1c\n\x14push_sparse_table_id\x18\x02 \x03(\x05\x12\x1b\n\x13push_dense_table_id\x18\x03 \x03(\x05\x12\x1c\n\x14pull_sparse_table_id\x18\x04 \x03(\x05\x12\x1b\n\x13pull_dense_table_id\x18\x05 \x03(\x05\"{\n\x13\x44\x65nseTableParameter\x12\x10\n\x08table_id\x18\x01 \x01(\x05\x12\x1b\n\x13\x64\x65nse_variable_name\x18\x02 \x03(\t\x12$\n\x1c\x64\x65nse_gradient_variable_name\x18\x03 \x03(\t\x12\x0f\n\x07\x66\x65\x61_dim\x18\x04 \x01(\x05\"z\n\x14SparseTableParameter\x12\x10\n\x08table_id\x18\x01 \x01(\x05\x12\x13\n\x0b\x66\x65\x61ture_dim\x18\x02 \x01(\x05\x12\x10\n\x08slot_key\x18\x03 \x03(\t\x12\x12\n\nslot_value\x18\x04 \x03(\t\x12\x15\n\rslot_gradient\x18\x05 \x03(\t\"\x86\x01\n\x17\x44ownpourServerParameter\x12\x34\n\x14\x64ownpour_table_param\x18\x01 \x03(\x0b\x32\x16.paddle.TableParameter\x12\x35\n\rservice_param\x18\x02 \x01(\x0b\x32\x1e.paddle.ServerServiceParameter\"\xd7\x01\n\x16ServerServiceParameter\x12*\n\x0cserver_class\x18\x01 \x01(\t:\x14\x44ownpourBrpcPsServer\x12*\n\x0c\x63lient_class\x18\x02 \x01(\t:\x14\x44ownpourBrpcPsClient\x12(\n\rservice_class\x18\x03 \x01(\t:\x11\x44ownpourPsService\x12\x1c\n\x11start_server_port\x18\x04 \x01(\r:\x01\x30\x12\x1d\n\x11server_thread_num\x18\x05 \x01(\r:\x02\x31\x32\"\xc0\x02\n\x0eTableParameter\x12\x10\n\x08table_id\x18\x01 \x01(\x04\x12\x13\n\x0btable_class\x18\x02 \x01(\t\x12\x17\n\tshard_num\x18\x03 \x01(\x04:\x04\x31\x30\x30\x30\x12\x30\n\x08\x61\x63\x63\x65ssor\x18\x04 \x01(\x0b\x32\x1e.paddle.TableAccessorParameter\x12\x1f\n\x04type\x18\x05 \x01(\x0e\x32\x11.paddle.TableType\x12\x1f\n\x10\x63ompress_in_save\x18\x06 \x01(\x08:\x05\x66\x61lse\x12\'\n\x19\x65nable_sparse_table_cache\x18\x07 \x01(\x08:\x04true\x12(\n\x17sparse_table_cache_rate\x18\x08 \x01(\x01:\x07\x30.00055\x12\'\n\x1bsparse_table_cache_file_num\x18\t \x01(\r:\x02\x31\x36\"\xc2\x03\n\x16TableAccessorParameter\x12\x16\n\x0e\x61\x63\x63\x65ssor_class\x18\x01 \x01(\t\x12\x38\n\x10sparse_sgd_param\x18\x02 \x01(\x0b\x32\x1e.paddle.SparseSGDRuleParameter\x12\x36\n\x0f\x64\x65nse_sgd_param\x18\x03 \x01(\x0b\x32\x1d.paddle.DenseSGDRuleParameter\x12\x13\n\x07\x66\x65\x61_dim\x18\x04 \x01(\r:\x02\x31\x31\x12\x15\n\nembedx_dim\x18\x05 \x01(\r:\x01\x38\x12\x1c\n\x10\x65mbedx_threshold\x18\x06 \x01(\r:\x02\x31\x30\x12G\n\x17\x64ownpour_accessor_param\x18\x07 \x01(\x0b\x32&.paddle.DownpourTableAccessorParameter\x12\x45\n\x19table_accessor_save_param\x18\x08 \x03(\x0b\x32\".paddle.TableAccessorSaveParameter\x12\x44\n\x16sparse_commonsgd_param\x18\t \x01(\x0b\x32$.paddle.SparseCommonSGDRuleParameter\"\x96\x02\n\x1e\x44ownpourTableAccessorParameter\x12\x19\n\x0cnonclk_coeff\x18\x01 \x01(\x02:\x03\x30.1\x12\x16\n\x0b\x63lick_coeff\x18\x02 \x01(\x02:\x01\x31\x12\x1b\n\x0e\x62\x61se_threshold\x18\x03 \x01(\x02:\x03\x31.5\x12\x1d\n\x0f\x64\x65lta_threshold\x18\x04 \x01(\x02:\x04\x30.25\x12\x1b\n\x0f\x64\x65lta_keep_days\x18\x05 \x01(\x02:\x02\x31\x36\x12#\n\x15show_click_decay_rate\x18\x06 \x01(\x02:\x04\x30.98\x12\x1d\n\x10\x64\x65lete_threshold\x18\x07 \x01(\x02:\x03\x30.8\x12$\n\x18\x64\x65lete_after_unseen_days\x18\x08 \x01(\x02:\x02\x33\x30\"S\n\x1aTableAccessorSaveParameter\x12\r\n\x05param\x18\x01 \x01(\r\x12\x11\n\tconverter\x18\x02 \x01(\t\x12\x13\n\x0b\x64\x65\x63onverter\x18\x03 \x01(\t\"e\n\x10PsRequestMessage\x12\x0e\n\x06\x63md_id\x18\x01 \x02(\r\x12\x10\n\x08table_id\x18\x02 \x01(\r\x12\x0e\n\x06params\x18\x03 \x03(\x0c\x12\x11\n\tclient_id\x18\x04 \x01(\x05\x12\x0c\n\x04\x64\x61ta\x18\x05 \x01(\x0c\"\x85\x01\n\x16SparseSGDRuleParameter\x12\x1b\n\rlearning_rate\x18\x01 \x01(\x01:\x04\x30.05\x12\x18\n\rinitial_g2sum\x18\x02 \x01(\x01:\x01\x33\x12\x1d\n\rinitial_range\x18\x03 \x01(\x01:\x06\x30.0001\x12\x15\n\rweight_bounds\x18\x04 \x03(\x02\"\xc6\x01\n\x1cSparseCommonSGDRuleParameter\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x32\n\x05naive\x18\x02 \x01(\x0b\x32#.paddle.SparseNaiveSGDRuleParameter\x12\x36\n\x07\x61\x64\x61grad\x18\x03 \x01(\x0b\x32%.paddle.SparseAdagradSGDRuleParameter\x12,\n\x04\x61\x64\x61m\x18\x04 \x01(\x0b\x32\x1e.paddle.SparseAdamSGDParameter\"p\n\x1bSparseNaiveSGDRuleParameter\x12\x1b\n\rlearning_rate\x18\x01 \x01(\x01:\x04\x30.05\x12\x1d\n\rinitial_range\x18\x02 \x01(\x01:\x06\x30.0001\x12\x15\n\rweight_bounds\x18\x03 \x03(\x02\"\x8c\x01\n\x1dSparseAdagradSGDRuleParameter\x12\x1b\n\rlearning_rate\x18\x01 \x01(\x01:\x04\x30.05\x12\x18\n\rinitial_g2sum\x18\x02 \x01(\x01:\x01\x33\x12\x1d\n\rinitial_range\x18\x03 \x01(\x01:\x06\x30.0001\x12\x15\n\rweight_bounds\x18\x04 \x03(\x02\"\xc8\x01\n\x16SparseAdamSGDParameter\x12\x1c\n\rlearning_rate\x18\x01 \x01(\x01:\x05\x30.001\x12\x1d\n\rinitial_range\x18\x02 \x01(\x01:\x06\x30.0001\x12\x1d\n\x10\x62\x65ta1_decay_rate\x18\x03 \x01(\x01:\x03\x30.9\x12\x1f\n\x10\x62\x65ta2_decay_rate\x18\x04 \x01(\x01:\x05\x30.999\x12\x1a\n\x0b\x61\x64\x61_epsilon\x18\x05 \x01(\x01:\x05\x31\x65-08\x12\x15\n\rweight_bounds\x18\x06 \x03(\x02\"\xe1\x01\n\x15\x44\x65nseSGDRuleParameter\x12\x0c\n\x04name\x18\x01 \x01(\t\x12&\n\x04\x61\x64\x61m\x18\x02 \x01(\x0b\x32\x18.paddle.AdamSGDParameter\x12(\n\x05naive\x18\x03 \x01(\x0b\x32\x19.paddle.NaiveSGDParameter\x12,\n\x07summary\x18\x04 \x01(\x0b\x32\x1b.paddle.SummarySGDParameter\x12:\n\x0emoving_average\x18\x05 \x01(\x0b\x32\".paddle.MovingAverageRuleParameter\"\xac\x01\n\x10\x41\x64\x61mSGDParameter\x12\x1c\n\rlearning_rate\x18\x01 \x01(\x01:\x05\x35\x65-06\x12 \n\x0e\x61vg_decay_rate\x18\x02 \x01(\x01:\x08\x30.999993\x12\x1e\n\x0e\x61\x64\x61_decay_rate\x18\x03 \x01(\x01:\x06\x30.9999\x12\x1a\n\x0b\x61\x64\x61_epsilon\x18\x04 \x01(\x01:\x05\x31\x65-08\x12\x1c\n\x0emom_decay_rate\x18\x05 \x01(\x01:\x04\x30.99\"J\n\x11NaiveSGDParameter\x12\x1d\n\rlearning_rate\x18\x01 \x01(\x01:\x06\x30.0002\x12\x16\n\x0e\x61vg_decay_rate\x18\x02 \x01(\x01\";\n\x13SummarySGDParameter\x12$\n\x12summary_decay_rate\x18\x01 \x01(\x01:\x08\x30.999999\".\n\x1aMovingAverageRuleParameter\x12\x10\n\x08momentum\x18\x01 \x01(\x01\"I\n\x11PsResponseMessage\x12\x13\n\x08\x65rr_code\x18\x01 \x02(\x05:\x01\x30\x12\x11\n\x07\x65rr_msg\x18\x02 \x02(\t:\x00\x12\x0c\n\x04\x64\x61ta\x18\x03 \x01(\x0c\"\xd5\x01\n\x11\x46sClientParameter\x12:\n\x07\x66s_type\x18\x01 \x01(\x0e\x32#.paddle.FsClientParameter.FsApiType:\x04HDFS\x12\x0b\n\x03uri\x18\x02 \x01(\t\x12\x0c\n\x04user\x18\x03 \x01(\t\x12\x0e\n\x06passwd\x18\x04 \x01(\t\x12\x13\n\x0b\x62uffer_size\x18\x05 \x01(\x05\x12\x12\n\nhadoop_bin\x18\x33 \x01(\t\x12\x10\n\x08\x61\x66s_conf\x18\x65 \x01(\t\"\x1e\n\tFsApiType\x12\x08\n\x04HDFS\x10\x00\x12\x07\n\x03\x41\x46S\x10\x01*4\n\tTableType\x12\x13\n\x0fPS_SPARSE_TABLE\x10\x00\x12\x12\n\x0ePS_DENSE_TABLE\x10\x01*\xa1\x04\n\x07PsCmdID\x12\x17\n\x13PS_PULL_DENSE_TABLE\x10\x00\x12\x17\n\x13PS_PUSH_DENSE_TABLE\x10\x01\x12\x18\n\x14PS_PULL_SPARSE_TABLE\x10\x02\x12\x18\n\x14PS_PUSH_SPARSE_TABLE\x10\x03\x12\x13\n\x0fPS_SHRINK_TABLE\x10\x04\x12\x15\n\x11PS_SAVE_ONE_TABLE\x10\x05\x12\x15\n\x11PS_SAVE_ALL_TABLE\x10\x06\x12\x15\n\x11PS_LOAD_ONE_TABLE\x10\x07\x12\x15\n\x11PS_LOAD_ALL_TABLE\x10\x08\x12\x16\n\x12PS_CLEAR_ONE_TABLE\x10\t\x12\x16\n\x12PS_CLEAR_ALL_TABLE\x10\n\x12\x17\n\x13PS_PUSH_DENSE_PARAM\x10\x0b\x12\x12\n\x0ePS_STOP_SERVER\x10\x0c\x12\x1b\n\x17PS_SAVE_ONE_CACHE_TABLE\x10\r\x12\x1a\n\x16PS_GET_CACHE_THRESHOLD\x10\x0e\x12\x14\n\x10PS_CACHE_SHUFFLE\x10\x0f\x12\x11\n\rPS_COPY_TABLE\x10\x10\x12\x1c\n\x18PS_COPY_TABLE_BY_FEASIGN\x10\x11\x12(\n$PS_PULL_SPARSE_TABLE_WITH_DEPENDENCY\x10\x12\x12(\n$PS_PUSH_SPARSE_TABLE_WITH_DEPENDENCY\x10\x13\x12\x0e\n\nPS_S2S_MSG\x10\x65\x32K\n\tPsService\x12>\n\x07service\x12\x18.paddle.PsRequestMessage\x1a\x19.paddle.PsResponseMessageB\x03\x80\x01\x01'
))
_sym_db.RegisterFileDescriptor(DESCRIPTOR)
......@@ -49,8 +49,8 @@ _TABLETYPE = _descriptor.EnumDescriptor(
],
containing_type=None,
options=None,
serialized_start=3762,
serialized_end=3814, )
serialized_start=4493,
serialized_end=4545, )
_sym_db.RegisterEnumDescriptor(_TABLETYPE)
TableType = enum_type_wrapper.EnumTypeWrapper(_TABLETYPE)
......@@ -150,12 +150,32 @@ _PSCMDID = _descriptor.EnumDescriptor(
options=None,
type=None),
_descriptor.EnumValueDescriptor(
name='PS_S2S_MSG', index=16, number=101, options=None, type=None),
name='PS_COPY_TABLE', index=16, number=16, options=None, type=None),
_descriptor.EnumValueDescriptor(
name='PS_COPY_TABLE_BY_FEASIGN',
index=17,
number=17,
options=None,
type=None),
_descriptor.EnumValueDescriptor(
name='PS_PULL_SPARSE_TABLE_WITH_DEPENDENCY',
index=18,
number=18,
options=None,
type=None),
_descriptor.EnumValueDescriptor(
name='PS_PUSH_SPARSE_TABLE_WITH_DEPENDENCY',
index=19,
number=19,
options=None,
type=None),
_descriptor.EnumValueDescriptor(
name='PS_S2S_MSG', index=20, number=101, options=None, type=None),
],
containing_type=None,
options=None,
serialized_start=3817,
serialized_end=4229, )
serialized_start=4548,
serialized_end=5093, )
_sym_db.RegisterEnumDescriptor(_PSCMDID)
PsCmdID = enum_type_wrapper.EnumTypeWrapper(_PSCMDID)
......@@ -177,6 +197,10 @@ PS_STOP_SERVER = 12
PS_SAVE_ONE_CACHE_TABLE = 13
PS_GET_CACHE_THRESHOLD = 14
PS_CACHE_SHUFFLE = 15
PS_COPY_TABLE = 16
PS_COPY_TABLE_BY_FEASIGN = 17
PS_PULL_SPARSE_TABLE_WITH_DEPENDENCY = 18
PS_PUSH_SPARSE_TABLE_WITH_DEPENDENCY = 19
PS_S2S_MSG = 101
_FSCLIENTPARAMETER_FSAPITYPE = _descriptor.EnumDescriptor(
......@@ -192,8 +216,8 @@ _FSCLIENTPARAMETER_FSAPITYPE = _descriptor.EnumDescriptor(
],
containing_type=None,
options=None,
serialized_start=3730,
serialized_end=3760, )
serialized_start=4461,
serialized_end=4491, )
_sym_db.RegisterEnumDescriptor(_FSCLIENTPARAMETER_FSAPITYPE)
_PSPARAMETER = _descriptor.Descriptor(
......@@ -1276,6 +1300,22 @@ _TABLEACCESSORPARAMETER = _descriptor.Descriptor(
is_extension=False,
extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='sparse_commonsgd_param',
full_name='paddle.TableAccessorParameter.sparse_commonsgd_param',
index=8,
number=9,
type=11,
cpp_type=10,
label=1,
has_default_value=False,
default_value=None,
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
],
extensions=[],
nested_types=[],
......@@ -1286,7 +1326,7 @@ _TABLEACCESSORPARAMETER = _descriptor.Descriptor(
extension_ranges=[],
oneofs=[],
serialized_start=1896,
serialized_end=2276, )
serialized_end=2346, )
_DOWNPOURTABLEACCESSORPARAMETER = _descriptor.Descriptor(
name='DownpourTableAccessorParameter',
......@@ -1432,8 +1472,8 @@ _DOWNPOURTABLEACCESSORPARAMETER = _descriptor.Descriptor(
syntax='proto2',
extension_ranges=[],
oneofs=[],
serialized_start=2279,
serialized_end=2557, )
serialized_start=2349,
serialized_end=2627, )
_TABLEACCESSORSAVEPARAMETER = _descriptor.Descriptor(
name='TableAccessorSaveParameter',
......@@ -1499,8 +1539,8 @@ _TABLEACCESSORSAVEPARAMETER = _descriptor.Descriptor(
syntax='proto2',
extension_ranges=[],
oneofs=[],
serialized_start=2559,
serialized_end=2642, )
serialized_start=2629,
serialized_end=2712, )
_PSREQUESTMESSAGE = _descriptor.Descriptor(
name='PsRequestMessage',
......@@ -1598,8 +1638,8 @@ _PSREQUESTMESSAGE = _descriptor.Descriptor(
syntax='proto2',
extension_ranges=[],
oneofs=[],
serialized_start=2644,
serialized_end=2745, )
serialized_start=2714,
serialized_end=2815, )
_SPARSESGDRULEPARAMETER = _descriptor.Descriptor(
name='SparseSGDRuleParameter',
......@@ -1681,8 +1721,356 @@ _SPARSESGDRULEPARAMETER = _descriptor.Descriptor(
syntax='proto2',
extension_ranges=[],
oneofs=[],
serialized_start=2748,
serialized_end=2881, )
serialized_start=2818,
serialized_end=2951, )
_SPARSECOMMONSGDRULEPARAMETER = _descriptor.Descriptor(
name='SparseCommonSGDRuleParameter',
full_name='paddle.SparseCommonSGDRuleParameter',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='name',
full_name='paddle.SparseCommonSGDRuleParameter.name',
index=0,
number=1,
type=9,
cpp_type=9,
label=1,
has_default_value=False,
default_value=_b("").decode('utf-8'),
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='naive',
full_name='paddle.SparseCommonSGDRuleParameter.naive',
index=1,
number=2,
type=11,
cpp_type=10,
label=1,
has_default_value=False,
default_value=None,
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='adagrad',
full_name='paddle.SparseCommonSGDRuleParameter.adagrad',
index=2,
number=3,
type=11,
cpp_type=10,
label=1,
has_default_value=False,
default_value=None,
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='adam',
full_name='paddle.SparseCommonSGDRuleParameter.adam',
index=3,
number=4,
type=11,
cpp_type=10,
label=1,
has_default_value=False,
default_value=None,
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
],
extensions=[],
nested_types=[],
enum_types=[],
options=None,
is_extendable=False,
syntax='proto2',
extension_ranges=[],
oneofs=[],
serialized_start=2954,
serialized_end=3152, )
_SPARSENAIVESGDRULEPARAMETER = _descriptor.Descriptor(
name='SparseNaiveSGDRuleParameter',
full_name='paddle.SparseNaiveSGDRuleParameter',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='learning_rate',
full_name='paddle.SparseNaiveSGDRuleParameter.learning_rate',
index=0,
number=1,
type=1,
cpp_type=5,
label=1,
has_default_value=True,
default_value=float(0.05),
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='initial_range',
full_name='paddle.SparseNaiveSGDRuleParameter.initial_range',
index=1,
number=2,
type=1,
cpp_type=5,
label=1,
has_default_value=True,
default_value=float(0.0001),
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='weight_bounds',
full_name='paddle.SparseNaiveSGDRuleParameter.weight_bounds',
index=2,
number=3,
type=2,
cpp_type=6,
label=3,
has_default_value=False,
default_value=[],
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
],
extensions=[],
nested_types=[],
enum_types=[],
options=None,
is_extendable=False,
syntax='proto2',
extension_ranges=[],
oneofs=[],
serialized_start=3154,
serialized_end=3266, )
_SPARSEADAGRADSGDRULEPARAMETER = _descriptor.Descriptor(
name='SparseAdagradSGDRuleParameter',
full_name='paddle.SparseAdagradSGDRuleParameter',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='learning_rate',
full_name='paddle.SparseAdagradSGDRuleParameter.learning_rate',
index=0,
number=1,
type=1,
cpp_type=5,
label=1,
has_default_value=True,
default_value=float(0.05),
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='initial_g2sum',
full_name='paddle.SparseAdagradSGDRuleParameter.initial_g2sum',
index=1,
number=2,
type=1,
cpp_type=5,
label=1,
has_default_value=True,
default_value=float(3),
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='initial_range',
full_name='paddle.SparseAdagradSGDRuleParameter.initial_range',
index=2,
number=3,
type=1,
cpp_type=5,
label=1,
has_default_value=True,
default_value=float(0.0001),
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='weight_bounds',
full_name='paddle.SparseAdagradSGDRuleParameter.weight_bounds',
index=3,
number=4,
type=2,
cpp_type=6,
label=3,
has_default_value=False,
default_value=[],
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
],
extensions=[],
nested_types=[],
enum_types=[],
options=None,
is_extendable=False,
syntax='proto2',
extension_ranges=[],
oneofs=[],
serialized_start=3269,
serialized_end=3409, )
_SPARSEADAMSGDPARAMETER = _descriptor.Descriptor(
name='SparseAdamSGDParameter',
full_name='paddle.SparseAdamSGDParameter',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='learning_rate',
full_name='paddle.SparseAdamSGDParameter.learning_rate',
index=0,
number=1,
type=1,
cpp_type=5,
label=1,
has_default_value=True,
default_value=float(0.001),
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='initial_range',
full_name='paddle.SparseAdamSGDParameter.initial_range',
index=1,
number=2,
type=1,
cpp_type=5,
label=1,
has_default_value=True,
default_value=float(0.0001),
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='beta1_decay_rate',
full_name='paddle.SparseAdamSGDParameter.beta1_decay_rate',
index=2,
number=3,
type=1,
cpp_type=5,
label=1,
has_default_value=True,
default_value=float(0.9),
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='beta2_decay_rate',
full_name='paddle.SparseAdamSGDParameter.beta2_decay_rate',
index=3,
number=4,
type=1,
cpp_type=5,
label=1,
has_default_value=True,
default_value=float(0.999),
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='ada_epsilon',
full_name='paddle.SparseAdamSGDParameter.ada_epsilon',
index=4,
number=5,
type=1,
cpp_type=5,
label=1,
has_default_value=True,
default_value=float(1e-08),
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='weight_bounds',
full_name='paddle.SparseAdamSGDParameter.weight_bounds',
index=5,
number=6,
type=2,
cpp_type=6,
label=3,
has_default_value=False,
default_value=[],
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
],
extensions=[],
nested_types=[],
enum_types=[],
options=None,
is_extendable=False,
syntax='proto2',
extension_ranges=[],
oneofs=[],
serialized_start=3412,
serialized_end=3612, )
_DENSESGDRULEPARAMETER = _descriptor.Descriptor(
name='DenseSGDRuleParameter',
......@@ -1780,8 +2168,8 @@ _DENSESGDRULEPARAMETER = _descriptor.Descriptor(
syntax='proto2',
extension_ranges=[],
oneofs=[],
serialized_start=2884,
serialized_end=3109, )
serialized_start=3615,
serialized_end=3840, )
_ADAMSGDPARAMETER = _descriptor.Descriptor(
name='AdamSGDParameter',
......@@ -1879,8 +2267,8 @@ _ADAMSGDPARAMETER = _descriptor.Descriptor(
syntax='proto2',
extension_ranges=[],
oneofs=[],
serialized_start=3112,
serialized_end=3284, )
serialized_start=3843,
serialized_end=4015, )
_NAIVESGDPARAMETER = _descriptor.Descriptor(
name='NaiveSGDParameter',
......@@ -1930,8 +2318,8 @@ _NAIVESGDPARAMETER = _descriptor.Descriptor(
syntax='proto2',
extension_ranges=[],
oneofs=[],
serialized_start=3286,
serialized_end=3360, )
serialized_start=4017,
serialized_end=4091, )
_SUMMARYSGDPARAMETER = _descriptor.Descriptor(
name='SummarySGDParameter',
......@@ -1965,8 +2353,8 @@ _SUMMARYSGDPARAMETER = _descriptor.Descriptor(
syntax='proto2',
extension_ranges=[],
oneofs=[],
serialized_start=3362,
serialized_end=3421, )
serialized_start=4093,
serialized_end=4152, )
_MOVINGAVERAGERULEPARAMETER = _descriptor.Descriptor(
name='MovingAverageRuleParameter',
......@@ -2000,8 +2388,8 @@ _MOVINGAVERAGERULEPARAMETER = _descriptor.Descriptor(
syntax='proto2',
extension_ranges=[],
oneofs=[],
serialized_start=3423,
serialized_end=3469, )
serialized_start=4154,
serialized_end=4200, )
_PSRESPONSEMESSAGE = _descriptor.Descriptor(
name='PsResponseMessage',
......@@ -2067,8 +2455,8 @@ _PSRESPONSEMESSAGE = _descriptor.Descriptor(
syntax='proto2',
extension_ranges=[],
oneofs=[],
serialized_start=3471,
serialized_end=3544, )
serialized_start=4202,
serialized_end=4275, )
_FSCLIENTPARAMETER = _descriptor.Descriptor(
name='FsClientParameter',
......@@ -2198,8 +2586,8 @@ _FSCLIENTPARAMETER = _descriptor.Descriptor(
syntax='proto2',
extension_ranges=[],
oneofs=[],
serialized_start=3547,
serialized_end=3760, )
serialized_start=4278,
serialized_end=4491, )
_PSPARAMETER.fields_by_name['worker_param'].message_type = _WORKERPARAMETER
_PSPARAMETER.fields_by_name['server_param'].message_type = _SERVERPARAMETER
......@@ -2233,6 +2621,14 @@ _TABLEACCESSORPARAMETER.fields_by_name[
'downpour_accessor_param'].message_type = _DOWNPOURTABLEACCESSORPARAMETER
_TABLEACCESSORPARAMETER.fields_by_name[
'table_accessor_save_param'].message_type = _TABLEACCESSORSAVEPARAMETER
_TABLEACCESSORPARAMETER.fields_by_name[
'sparse_commonsgd_param'].message_type = _SPARSECOMMONSGDRULEPARAMETER
_SPARSECOMMONSGDRULEPARAMETER.fields_by_name[
'naive'].message_type = _SPARSENAIVESGDRULEPARAMETER
_SPARSECOMMONSGDRULEPARAMETER.fields_by_name[
'adagrad'].message_type = _SPARSEADAGRADSGDRULEPARAMETER
_SPARSECOMMONSGDRULEPARAMETER.fields_by_name[
'adam'].message_type = _SPARSEADAMSGDPARAMETER
_DENSESGDRULEPARAMETER.fields_by_name['adam'].message_type = _ADAMSGDPARAMETER
_DENSESGDRULEPARAMETER.fields_by_name['naive'].message_type = _NAIVESGDPARAMETER
_DENSESGDRULEPARAMETER.fields_by_name[
......@@ -2266,6 +2662,14 @@ DESCRIPTOR.message_types_by_name[
DESCRIPTOR.message_types_by_name['PsRequestMessage'] = _PSREQUESTMESSAGE
DESCRIPTOR.message_types_by_name[
'SparseSGDRuleParameter'] = _SPARSESGDRULEPARAMETER
DESCRIPTOR.message_types_by_name[
'SparseCommonSGDRuleParameter'] = _SPARSECOMMONSGDRULEPARAMETER
DESCRIPTOR.message_types_by_name[
'SparseNaiveSGDRuleParameter'] = _SPARSENAIVESGDRULEPARAMETER
DESCRIPTOR.message_types_by_name[
'SparseAdagradSGDRuleParameter'] = _SPARSEADAGRADSGDRULEPARAMETER
DESCRIPTOR.message_types_by_name[
'SparseAdamSGDParameter'] = _SPARSEADAMSGDPARAMETER
DESCRIPTOR.message_types_by_name[
'DenseSGDRuleParameter'] = _DENSESGDRULEPARAMETER
DESCRIPTOR.message_types_by_name['AdamSGDParameter'] = _ADAMSGDPARAMETER
......@@ -2438,6 +2842,46 @@ SparseSGDRuleParameter = _reflection.GeneratedProtocolMessageType(
))
_sym_db.RegisterMessage(SparseSGDRuleParameter)
SparseCommonSGDRuleParameter = _reflection.GeneratedProtocolMessageType(
'SparseCommonSGDRuleParameter',
(_message.Message, ),
dict(
DESCRIPTOR=_SPARSECOMMONSGDRULEPARAMETER,
__module__='ps_pb2'
# @@protoc_insertion_point(class_scope:paddle.SparseCommonSGDRuleParameter)
))
_sym_db.RegisterMessage(SparseCommonSGDRuleParameter)
SparseNaiveSGDRuleParameter = _reflection.GeneratedProtocolMessageType(
'SparseNaiveSGDRuleParameter',
(_message.Message, ),
dict(
DESCRIPTOR=_SPARSENAIVESGDRULEPARAMETER,
__module__='ps_pb2'
# @@protoc_insertion_point(class_scope:paddle.SparseNaiveSGDRuleParameter)
))
_sym_db.RegisterMessage(SparseNaiveSGDRuleParameter)
SparseAdagradSGDRuleParameter = _reflection.GeneratedProtocolMessageType(
'SparseAdagradSGDRuleParameter',
(_message.Message, ),
dict(
DESCRIPTOR=_SPARSEADAGRADSGDRULEPARAMETER,
__module__='ps_pb2'
# @@protoc_insertion_point(class_scope:paddle.SparseAdagradSGDRuleParameter)
))
_sym_db.RegisterMessage(SparseAdagradSGDRuleParameter)
SparseAdamSGDParameter = _reflection.GeneratedProtocolMessageType(
'SparseAdamSGDParameter',
(_message.Message, ),
dict(
DESCRIPTOR=_SPARSEADAMSGDPARAMETER,
__module__='ps_pb2'
# @@protoc_insertion_point(class_scope:paddle.SparseAdamSGDParameter)
))
_sym_db.RegisterMessage(SparseAdamSGDParameter)
DenseSGDRuleParameter = _reflection.GeneratedProtocolMessageType(
'DenseSGDRuleParameter',
(_message.Message, ),
......
......@@ -76,6 +76,9 @@ class TrainerDesc(object):
def _set_use_cvm(self, use_cvm=False):
self.proto_desc.use_cvm = use_cvm
def _set_no_cvm(self, no_cvm=False):
self.proto_desc.no_cvm = no_cvm
def _set_scale_datanorm(self, scale_datanorm=-1):
self.proto_desc.scale_datanorm = scale_datanorm
......@@ -101,6 +104,10 @@ class TrainerDesc(object):
def _set_dump_converter(self, converter):
self.proto_desc.dump_converter = converter
def _set_dump_param(self, dump_param):
for param in dump_param:
self.proto_desc.dump_param.append(param)
def _set_check_nan_var_names(self, check_nan_var_names):
for var in check_nan_var_names:
self.proto_desc.check_nan_var_names.append(var)
......
......@@ -52,6 +52,8 @@ class TrainerFactory(object):
trainer._set_fleet_desc(opt_info["fleet_desc"])
if opt_info.get("use_cvm") is not None:
trainer._set_use_cvm(opt_info["use_cvm"])
if opt_info.get("no_cvm") is not None:
trainer._set_no_cvm(opt_info["no_cvm"])
if opt_info.get("scale_datanorm") is not None:
trainer._set_scale_datanorm(opt_info["scale_datanorm"])
if opt_info.get("dump_slot") is not None:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册