From 0ed9b051b44792fd2f132701f7d90d2ef719daf5 Mon Sep 17 00:00:00 2001 From: Fan Zhang Date: Tue, 7 Sep 2021 13:16:20 +0800 Subject: [PATCH] [CPU-PSLIB] Add dump_prob & Modify dump_param (#35414) * Add dump_prob & Modify dump_param * Add dump_prob & Modify dump_param * Add dump_prob & Modify dump_param * Add dump_prob & Modify dump_param * Add dump_prob & Modify dump_param --- cmake/external/pslib.cmake | 2 +- paddle/fluid/framework/data_feed.cc | 16 +- paddle/fluid/framework/device_worker.h | 2 + paddle/fluid/framework/dist_multi_trainer.cc | 5 + paddle/fluid/framework/downpour_worker.cc | 68 ++++-- paddle/fluid/framework/trainer_desc.proto | 1 + .../fleet/parameter_server/pslib/ps_pb2.py | 224 ++++++++++++++---- python/paddle/fluid/trainer_desc.py | 3 + python/paddle/fluid/trainer_factory.py | 5 + 9 files changed, 257 insertions(+), 69 deletions(-) diff --git a/cmake/external/pslib.cmake b/cmake/external/pslib.cmake index bdfd335172d..d5af4783555 100644 --- a/cmake/external/pslib.cmake +++ b/cmake/external/pslib.cmake @@ -19,7 +19,7 @@ IF((NOT DEFINED PSLIB_VER) OR (NOT DEFINED PSLIB_URL)) MESSAGE(STATUS "use pre defined download url") SET(PSLIB_VER "0.1.1" CACHE STRING "" FORCE) SET(PSLIB_NAME "pslib" CACHE STRING "" FORCE) - SET(PSLIB_URL "https://pslib.bj.bcebos.com/pslib.tar.gz" CACHE STRING "" FORCE) + SET(PSLIB_URL "https://pslib.bj.bcebos.com/pslib_1.8.5.tar.gz" CACHE STRING "" FORCE) ENDIF() MESSAGE(STATUS "PSLIB_NAME: ${PSLIB_NAME}, PSLIB_URL: ${PSLIB_URL}") SET(PSLIB_SOURCE_DIR "${THIRD_PARTY_PATH}/pslib") diff --git a/paddle/fluid/framework/data_feed.cc b/paddle/fluid/framework/data_feed.cc index 41c59c5893f..efe9e2e6576 100644 --- a/paddle/fluid/framework/data_feed.cc +++ b/paddle/fluid/framework/data_feed.cc @@ -775,8 +775,12 @@ void MultiSlotDataFeed::PutToFeedVec( total_instance * sizeof(int64_t)); } - LoD data_lod{offset}; - feed_vec_[i]->set_lod(data_lod); + // LoD data_lod{offset}; + // feed_vec_[i]->set_lod(data_lod); + if (!use_slots_is_dense_[i]) { + LoD data_lod{offset}; + feed_vec_[i]->set_lod(data_lod); + } if (use_slots_is_dense_[i]) { if (inductive_shape_index_[i] != -1) { use_slots_shape_[i][inductive_shape_index_[i]] = @@ -1101,8 +1105,12 @@ void MultiSlotInMemoryDataFeed::PutToFeedVec( CopyToFeedTensor(tensor_ptr, feasign, total_instance * sizeof(int64_t)); } auto& slot_offset = offset[i]; - LoD data_lod{slot_offset}; - feed_vec_[i]->set_lod(data_lod); + // LoD data_lod{slot_offset}; + // feed_vec_[i]->set_lod(data_lod); + if (!use_slots_is_dense_[i]) { + LoD data_lod{slot_offset}; + feed_vec_[i]->set_lod(data_lod); + } if (use_slots_is_dense_[i]) { if (inductive_shape_index_[i] != -1) { use_slots_shape_[i][inductive_shape_index_[i]] = diff --git a/paddle/fluid/framework/device_worker.h b/paddle/fluid/framework/device_worker.h index 4e640b09675..7c7dd1dfea9 100644 --- a/paddle/fluid/framework/device_worker.h +++ b/paddle/fluid/framework/device_worker.h @@ -149,6 +149,7 @@ class DeviceWorker { bool use_cvm_; bool no_cvm_; bool scale_sparse_gradient_with_batch_size_; + float dump_prob_; std::vector all_param_; }; @@ -226,6 +227,7 @@ class DownpourWorker : public HogwildWorker { void CopySparseTable(); void CopyDenseTable(); void CopyDenseVars(); + void DumpParam(std::ostringstream& os); virtual void DumpParam(const int batch_id); DownpourWorkerParameter param_; diff --git a/paddle/fluid/framework/dist_multi_trainer.cc b/paddle/fluid/framework/dist_multi_trainer.cc index d16699babae..9df8e2c534e 100644 --- a/paddle/fluid/framework/dist_multi_trainer.cc +++ b/paddle/fluid/framework/dist_multi_trainer.cc @@ -78,6 +78,11 @@ void DistMultiTrainer::DumpWork(int tid) { std::string path = string::format_string( "%s/part-%03d-%05d", dump_fields_path_.c_str(), mpi_rank_, tid); + if (user_define_dump_filename_ != "") { + path = string::format_string("%s/part-%s", dump_fields_path_.c_str(), + user_define_dump_filename_.c_str()); + } + std::shared_ptr fp = fs_open_write(path, &err_no, dump_converter_); while (1) { std::string out_str; diff --git a/paddle/fluid/framework/downpour_worker.cc b/paddle/fluid/framework/downpour_worker.cc index 9e04cca0935..cf3274ae5ba 100644 --- a/paddle/fluid/framework/downpour_worker.cc +++ b/paddle/fluid/framework/downpour_worker.cc @@ -78,6 +78,7 @@ void DownpourWorker::Initialize(const TrainerDesc& desc) { use_cvm_ = desc.use_cvm(); // for sparse value accessor, embedding only no_cvm_ = desc.no_cvm(); + dump_prob_ = desc.dump_prob(); scale_sparse_gradient_with_batch_size_ = desc.scale_sparse_gradient_with_batch_size(); scale_datanorm_ = desc.scale_datanorm(); @@ -131,6 +132,25 @@ void DownpourWorker::SetNeedDump(bool need_dump_field) { need_dump_field_ = need_dump_field; } +void DownpourWorker::DumpParam(std::ostringstream& os) { + for (auto& param : dump_param_) { + Variable* var = thread_scope_->FindVar(param); + if (var == nullptr) { + continue; + } + LoDTensor* tensor = var->GetMutable(); + int64_t len = tensor->numel(); + std::string tensor_str; + try { + tensor_str = PrintLodTensor(tensor, 0, len); + } catch (std::exception& e) { + LOG(WARNING) << "catch exception, param:" << param; + continue; + } + os << "\t" << param << ":" << len << tensor_str; + } +} + void DownpourWorker::DumpParam(const int batch_id) { std::ostringstream os; for (auto& param : dump_param_) { @@ -963,43 +983,47 @@ void DownpourWorker::TrainFiles() { } if (need_dump_field_) { size_t batch_size = device_reader_->GetCurBatchSize(); - std::vector ars(batch_size); + std::vector ars(batch_size); for (auto& ar : ars) { ar.clear(); } auto& ins_id_vec = device_reader_->GetInsIdVec(); auto& ins_content_vec = device_reader_->GetInsContentVec(); for (size_t i = 0; i < ins_id_vec.size(); i++) { - ars[i] += ins_id_vec[i]; - ars[i] = ars[i] + "\t" + ins_content_vec[i]; - } - for (auto& field : dump_fields_) { - Variable* var = thread_scope_->FindVar(field); - if (var == nullptr) { - continue; - } - LoDTensor* tensor = var->GetMutable(); - if (!CheckValidOutput(tensor, batch_size)) { + srand((unsigned)time(NULL)); + float random_prob = (float)rand() / RAND_MAX; // NOLINT + if (random_prob >= dump_prob_) { continue; } - for (size_t i = 0; i < batch_size; ++i) { + ars[i] << ins_id_vec[i]; + ars[i] << "\t" << ins_content_vec[i]; + + for (auto& field : dump_fields_) { + Variable* var = thread_scope_->FindVar(field); + if (var == nullptr) { + continue; + } + LoDTensor* tensor = var->GetMutable(); + if (!CheckValidOutput(tensor, batch_size)) { + continue; + } auto output_dim = tensor->dims()[1]; std::string output_dimstr = boost::lexical_cast(output_dim); - ars[i] = ars[i] + "\t" + field + ":" + output_dimstr; + ars[i] << "\t" << field << ":" << output_dimstr; auto bound = GetTensorBound(tensor, i); - ars[i] += PrintLodTensor(tensor, bound.first, bound.second); + ars[i] << PrintLodTensor(tensor, bound.first, bound.second); } - } - // #pragma omp parallel for - for (size_t i = 0; i < ars.size(); i++) { - if (ars[i].length() == 0) { + + if (need_dump_param_ && thread_id_ == 0) { + DumpParam(ars[i]); + } + + if (ars[i].str().length() < 2) { continue; } - writer_ << ars[i]; - } - if (need_dump_param_ && thread_id_ == 0) { - DumpParam(batch_cnt); + + writer_ << ars[i].str(); } } diff --git a/paddle/fluid/framework/trainer_desc.proto b/paddle/fluid/framework/trainer_desc.proto index f18edcc7fff..0cc9c96f0ea 100644 --- a/paddle/fluid/framework/trainer_desc.proto +++ b/paddle/fluid/framework/trainer_desc.proto @@ -52,6 +52,7 @@ message TrainerDesc { optional string user_define_dump_filename = 24; optional bool scale_sparse_gradient_with_batch_size = 25 [ default = true ]; + optional float dump_prob = 26 [ default = 1.0 ]; // device worker parameters optional HogwildWorkerParameter hogwild_param = 101; diff --git a/python/paddle/fluid/incubate/fleet/parameter_server/pslib/ps_pb2.py b/python/paddle/fluid/incubate/fleet/parameter_server/pslib/ps_pb2.py index 363475b3013..e4cbc580444 100644 --- a/python/paddle/fluid/incubate/fleet/parameter_server/pslib/ps_pb2.py +++ b/python/paddle/fluid/incubate/fleet/parameter_server/pslib/ps_pb2.py @@ -1,4 +1,4 @@ -# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -32,7 +32,7 @@ DESCRIPTOR = _descriptor.FileDescriptor( package='paddle', syntax='proto2', serialized_pb=_b( - '\n\x08ps.proto\x12\x06paddle\"\xb5\x02\n\x0bPSParameter\x12\x14\n\x0cworker_class\x18\x01 \x01(\t\x12\x14\n\x0cserver_class\x18\x02 \x01(\t\x12\x16\n\x0einstance_class\x18\x03 \x01(\t\x12\x15\n\x0binit_gflags\x18\x04 \x01(\t:\x00\x12-\n\x0cworker_param\x18\x65 \x01(\x0b\x32\x17.paddle.WorkerParameter\x12-\n\x0cserver_param\x18\x66 \x01(\x0b\x32\x17.paddle.ServerParameter\x12\x38\n\rtrainer_param\x18\xad\x02 \x03(\x0b\x32 .paddle.DownpourTrainerParameter\x12\x33\n\x0f\x66s_client_param\x18\xf5\x03 \x01(\x0b\x32\x19.paddle.FsClientParameter\"Q\n\x0fWorkerParameter\x12>\n\x15\x64ownpour_worker_param\x18\x01 \x01(\x0b\x32\x1f.paddle.DownpourWorkerParameter\"Q\n\x0fServerParameter\x12>\n\x15\x64ownpour_server_param\x18\x01 \x01(\x0b\x32\x1f.paddle.DownpourServerParameter\"O\n\x17\x44ownpourWorkerParameter\x12\x34\n\x14\x64ownpour_table_param\x18\x01 \x03(\x0b\x32\x16.paddle.TableParameter\"\xfd\x01\n\x18\x44ownpourTrainerParameter\x12\x30\n\x0b\x64\x65nse_table\x18\x01 \x03(\x0b\x32\x1b.paddle.DenseTableParameter\x12\x32\n\x0csparse_table\x18\x02 \x03(\x0b\x32\x1c.paddle.SparseTableParameter\x12\x1d\n\x15push_sparse_per_batch\x18\x03 \x01(\x05\x12\x1c\n\x14push_dense_per_batch\x18\x04 \x01(\x05\x12\x0f\n\x07skip_op\x18\x05 \x03(\t\x12-\n\x0eprogram_config\x18\x06 \x03(\x0b\x32\x15.paddle.ProgramConfig\"\x99\x01\n\rProgramConfig\x12\x12\n\nprogram_id\x18\x01 \x02(\t\x12\x1c\n\x14push_sparse_table_id\x18\x02 \x03(\x05\x12\x1b\n\x13push_dense_table_id\x18\x03 \x03(\x05\x12\x1c\n\x14pull_sparse_table_id\x18\x04 \x03(\x05\x12\x1b\n\x13pull_dense_table_id\x18\x05 \x03(\x05\"{\n\x13\x44\x65nseTableParameter\x12\x10\n\x08table_id\x18\x01 \x01(\x05\x12\x1b\n\x13\x64\x65nse_variable_name\x18\x02 \x03(\t\x12$\n\x1c\x64\x65nse_gradient_variable_name\x18\x03 \x03(\t\x12\x0f\n\x07\x66\x65\x61_dim\x18\x04 \x01(\x05\"z\n\x14SparseTableParameter\x12\x10\n\x08table_id\x18\x01 \x01(\x05\x12\x13\n\x0b\x66\x65\x61ture_dim\x18\x02 \x01(\x05\x12\x10\n\x08slot_key\x18\x03 \x03(\t\x12\x12\n\nslot_value\x18\x04 \x03(\t\x12\x15\n\rslot_gradient\x18\x05 \x03(\t\"\x86\x01\n\x17\x44ownpourServerParameter\x12\x34\n\x14\x64ownpour_table_param\x18\x01 \x03(\x0b\x32\x16.paddle.TableParameter\x12\x35\n\rservice_param\x18\x02 \x01(\x0b\x32\x1e.paddle.ServerServiceParameter\"\xd7\x01\n\x16ServerServiceParameter\x12*\n\x0cserver_class\x18\x01 \x01(\t:\x14\x44ownpourBrpcPsServer\x12*\n\x0c\x63lient_class\x18\x02 \x01(\t:\x14\x44ownpourBrpcPsClient\x12(\n\rservice_class\x18\x03 \x01(\t:\x11\x44ownpourPsService\x12\x1c\n\x11start_server_port\x18\x04 \x01(\r:\x01\x30\x12\x1d\n\x11server_thread_num\x18\x05 \x01(\r:\x02\x31\x32\"\xc0\x02\n\x0eTableParameter\x12\x10\n\x08table_id\x18\x01 \x01(\x04\x12\x13\n\x0btable_class\x18\x02 \x01(\t\x12\x17\n\tshard_num\x18\x03 \x01(\x04:\x04\x31\x30\x30\x30\x12\x30\n\x08\x61\x63\x63\x65ssor\x18\x04 \x01(\x0b\x32\x1e.paddle.TableAccessorParameter\x12\x1f\n\x04type\x18\x05 \x01(\x0e\x32\x11.paddle.TableType\x12\x1f\n\x10\x63ompress_in_save\x18\x06 \x01(\x08:\x05\x66\x61lse\x12\'\n\x19\x65nable_sparse_table_cache\x18\x07 \x01(\x08:\x04true\x12(\n\x17sparse_table_cache_rate\x18\x08 \x01(\x01:\x07\x30.00055\x12\'\n\x1bsparse_table_cache_file_num\x18\t \x01(\r:\x02\x31\x36\"\xc1\x04\n\x16TableAccessorParameter\x12\x16\n\x0e\x61\x63\x63\x65ssor_class\x18\x01 \x01(\t\x12\x38\n\x10sparse_sgd_param\x18\x02 \x01(\x0b\x32\x1e.paddle.SparseSGDRuleParameter\x12\x36\n\x0f\x64\x65nse_sgd_param\x18\x03 \x01(\x0b\x32\x1d.paddle.DenseSGDRuleParameter\x12\x13\n\x07\x66\x65\x61_dim\x18\x04 \x01(\r:\x02\x31\x31\x12\x15\n\nembedx_dim\x18\x05 \x01(\r:\x01\x38\x12\x1c\n\x10\x65mbedx_threshold\x18\x06 \x01(\r:\x02\x31\x30\x12G\n\x17\x64ownpour_accessor_param\x18\x07 \x01(\x0b\x32&.paddle.DownpourTableAccessorParameter\x12\x45\n\x19table_accessor_save_param\x18\x08 \x03(\x0b\x32\".paddle.TableAccessorSaveParameter\x12\x44\n\x16sparse_commonsgd_param\x18\t \x01(\x0b\x32$.paddle.SparseCommonSGDRuleParameter\x12=\n\x0f\x65mbed_sgd_param\x18\n \x01(\x0b\x32$.paddle.SparseCommonSGDRuleParameter\x12>\n\x10\x65mbedx_sgd_param\x18\x0b \x01(\x0b\x32$.paddle.SparseCommonSGDRuleParameter\"\xba\x02\n\x1e\x44ownpourTableAccessorParameter\x12\x19\n\x0cnonclk_coeff\x18\x01 \x01(\x02:\x03\x30.1\x12\x16\n\x0b\x63lick_coeff\x18\x02 \x01(\x02:\x01\x31\x12\x1b\n\x0e\x62\x61se_threshold\x18\x03 \x01(\x02:\x03\x31.5\x12\x1d\n\x0f\x64\x65lta_threshold\x18\x04 \x01(\x02:\x04\x30.25\x12\x1b\n\x0f\x64\x65lta_keep_days\x18\x05 \x01(\x02:\x02\x31\x36\x12#\n\x15show_click_decay_rate\x18\x06 \x01(\x02:\x04\x30.98\x12\x1d\n\x10\x64\x65lete_threshold\x18\x07 \x01(\x02:\x03\x30.8\x12$\n\x18\x64\x65lete_after_unseen_days\x18\x08 \x01(\x02:\x02\x33\x30\x12\"\n\x17ssd_unseenday_threshold\x18\t \x01(\x05:\x01\x31\"S\n\x1aTableAccessorSaveParameter\x12\r\n\x05param\x18\x01 \x01(\r\x12\x11\n\tconverter\x18\x02 \x01(\t\x12\x13\n\x0b\x64\x65\x63onverter\x18\x03 \x01(\t\"e\n\x10PsRequestMessage\x12\x0e\n\x06\x63md_id\x18\x01 \x02(\r\x12\x10\n\x08table_id\x18\x02 \x01(\r\x12\x0e\n\x06params\x18\x03 \x03(\x0c\x12\x11\n\tclient_id\x18\x04 \x01(\x05\x12\x0c\n\x04\x64\x61ta\x18\x05 \x01(\x0c\"\x85\x01\n\x16SparseSGDRuleParameter\x12\x1b\n\rlearning_rate\x18\x01 \x01(\x01:\x04\x30.05\x12\x18\n\rinitial_g2sum\x18\x02 \x01(\x01:\x01\x33\x12\x1d\n\rinitial_range\x18\x03 \x01(\x01:\x06\x30.0001\x12\x15\n\rweight_bounds\x18\x04 \x03(\x02\"\xc6\x01\n\x1cSparseCommonSGDRuleParameter\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x32\n\x05naive\x18\x02 \x01(\x0b\x32#.paddle.SparseNaiveSGDRuleParameter\x12\x36\n\x07\x61\x64\x61grad\x18\x03 \x01(\x0b\x32%.paddle.SparseAdagradSGDRuleParameter\x12,\n\x04\x61\x64\x61m\x18\x04 \x01(\x0b\x32\x1e.paddle.SparseAdamSGDParameter\"p\n\x1bSparseNaiveSGDRuleParameter\x12\x1b\n\rlearning_rate\x18\x01 \x01(\x01:\x04\x30.05\x12\x1d\n\rinitial_range\x18\x02 \x01(\x01:\x06\x30.0001\x12\x15\n\rweight_bounds\x18\x03 \x03(\x02\"\x8c\x01\n\x1dSparseAdagradSGDRuleParameter\x12\x1b\n\rlearning_rate\x18\x01 \x01(\x01:\x04\x30.05\x12\x18\n\rinitial_g2sum\x18\x02 \x01(\x01:\x01\x33\x12\x1d\n\rinitial_range\x18\x03 \x01(\x01:\x06\x30.0001\x12\x15\n\rweight_bounds\x18\x04 \x03(\x02\"\xc8\x01\n\x16SparseAdamSGDParameter\x12\x1c\n\rlearning_rate\x18\x01 \x01(\x01:\x05\x30.001\x12\x1d\n\rinitial_range\x18\x02 \x01(\x01:\x06\x30.0001\x12\x1d\n\x10\x62\x65ta1_decay_rate\x18\x03 \x01(\x01:\x03\x30.9\x12\x1f\n\x10\x62\x65ta2_decay_rate\x18\x04 \x01(\x01:\x05\x30.999\x12\x1a\n\x0b\x61\x64\x61_epsilon\x18\x05 \x01(\x01:\x05\x31\x65-08\x12\x15\n\rweight_bounds\x18\x06 \x03(\x02\"\xe1\x01\n\x15\x44\x65nseSGDRuleParameter\x12\x0c\n\x04name\x18\x01 \x01(\t\x12&\n\x04\x61\x64\x61m\x18\x02 \x01(\x0b\x32\x18.paddle.AdamSGDParameter\x12(\n\x05naive\x18\x03 \x01(\x0b\x32\x19.paddle.NaiveSGDParameter\x12,\n\x07summary\x18\x04 \x01(\x0b\x32\x1b.paddle.SummarySGDParameter\x12:\n\x0emoving_average\x18\x05 \x01(\x0b\x32\".paddle.MovingAverageRuleParameter\"\xac\x01\n\x10\x41\x64\x61mSGDParameter\x12\x1c\n\rlearning_rate\x18\x01 \x01(\x01:\x05\x35\x65-06\x12 \n\x0e\x61vg_decay_rate\x18\x02 \x01(\x01:\x08\x30.999993\x12\x1e\n\x0e\x61\x64\x61_decay_rate\x18\x03 \x01(\x01:\x06\x30.9999\x12\x1a\n\x0b\x61\x64\x61_epsilon\x18\x04 \x01(\x01:\x05\x31\x65-08\x12\x1c\n\x0emom_decay_rate\x18\x05 \x01(\x01:\x04\x30.99\"J\n\x11NaiveSGDParameter\x12\x1d\n\rlearning_rate\x18\x01 \x01(\x01:\x06\x30.0002\x12\x16\n\x0e\x61vg_decay_rate\x18\x02 \x01(\x01\";\n\x13SummarySGDParameter\x12$\n\x12summary_decay_rate\x18\x01 \x01(\x01:\x08\x30.999999\".\n\x1aMovingAverageRuleParameter\x12\x10\n\x08momentum\x18\x01 \x01(\x01\"I\n\x11PsResponseMessage\x12\x13\n\x08\x65rr_code\x18\x01 \x02(\x05:\x01\x30\x12\x11\n\x07\x65rr_msg\x18\x02 \x02(\t:\x00\x12\x0c\n\x04\x64\x61ta\x18\x03 \x01(\x0c\"\xd5\x01\n\x11\x46sClientParameter\x12:\n\x07\x66s_type\x18\x01 \x01(\x0e\x32#.paddle.FsClientParameter.FsApiType:\x04HDFS\x12\x0b\n\x03uri\x18\x02 \x01(\t\x12\x0c\n\x04user\x18\x03 \x01(\t\x12\x0e\n\x06passwd\x18\x04 \x01(\t\x12\x13\n\x0b\x62uffer_size\x18\x05 \x01(\x05\x12\x12\n\nhadoop_bin\x18\x33 \x01(\t\x12\x10\n\x08\x61\x66s_conf\x18\x65 \x01(\t\"\x1e\n\tFsApiType\x12\x08\n\x04HDFS\x10\x00\x12\x07\n\x03\x41\x46S\x10\x01*4\n\tTableType\x12\x13\n\x0fPS_SPARSE_TABLE\x10\x00\x12\x12\n\x0ePS_DENSE_TABLE\x10\x01*\xba\x04\n\x07PsCmdID\x12\x17\n\x13PS_PULL_DENSE_TABLE\x10\x00\x12\x17\n\x13PS_PUSH_DENSE_TABLE\x10\x01\x12\x18\n\x14PS_PULL_SPARSE_TABLE\x10\x02\x12\x18\n\x14PS_PUSH_SPARSE_TABLE\x10\x03\x12\x13\n\x0fPS_SHRINK_TABLE\x10\x04\x12\x15\n\x11PS_SAVE_ONE_TABLE\x10\x05\x12\x15\n\x11PS_SAVE_ALL_TABLE\x10\x06\x12\x15\n\x11PS_LOAD_ONE_TABLE\x10\x07\x12\x15\n\x11PS_LOAD_ALL_TABLE\x10\x08\x12\x16\n\x12PS_CLEAR_ONE_TABLE\x10\t\x12\x16\n\x12PS_CLEAR_ALL_TABLE\x10\n\x12\x17\n\x13PS_PUSH_DENSE_PARAM\x10\x0b\x12\x12\n\x0ePS_STOP_SERVER\x10\x0c\x12\x1b\n\x17PS_SAVE_ONE_CACHE_TABLE\x10\r\x12\x1a\n\x16PS_GET_CACHE_THRESHOLD\x10\x0e\x12\x14\n\x10PS_CACHE_SHUFFLE\x10\x0f\x12\x11\n\rPS_COPY_TABLE\x10\x10\x12\x1c\n\x18PS_COPY_TABLE_BY_FEASIGN\x10\x11\x12(\n$PS_PULL_SPARSE_TABLE_WITH_DEPENDENCY\x10\x12\x12(\n$PS_PUSH_SPARSE_TABLE_WITH_DEPENDENCY\x10\x13\x12\x17\n\x13PS_PRINT_TABLE_STAT\x10\x14\x12\x0e\n\nPS_S2S_MSG\x10\x65\x32K\n\tPsService\x12>\n\x07service\x12\x18.paddle.PsRequestMessage\x1a\x19.paddle.PsResponseMessageB\x06\x80\x01\x01\xf8\x01\x01' + '\n\x08ps.proto\x12\x06paddle\"\xb5\x02\n\x0bPSParameter\x12\x14\n\x0cworker_class\x18\x01 \x01(\t\x12\x14\n\x0cserver_class\x18\x02 \x01(\t\x12\x16\n\x0einstance_class\x18\x03 \x01(\t\x12\x15\n\x0binit_gflags\x18\x04 \x01(\t:\x00\x12-\n\x0cworker_param\x18\x65 \x01(\x0b\x32\x17.paddle.WorkerParameter\x12-\n\x0cserver_param\x18\x66 \x01(\x0b\x32\x17.paddle.ServerParameter\x12\x38\n\rtrainer_param\x18\xad\x02 \x03(\x0b\x32 .paddle.DownpourTrainerParameter\x12\x33\n\x0f\x66s_client_param\x18\xf5\x03 \x01(\x0b\x32\x19.paddle.FsClientParameter\"Q\n\x0fWorkerParameter\x12>\n\x15\x64ownpour_worker_param\x18\x01 \x01(\x0b\x32\x1f.paddle.DownpourWorkerParameter\"Q\n\x0fServerParameter\x12>\n\x15\x64ownpour_server_param\x18\x01 \x01(\x0b\x32\x1f.paddle.DownpourServerParameter\"O\n\x17\x44ownpourWorkerParameter\x12\x34\n\x14\x64ownpour_table_param\x18\x01 \x03(\x0b\x32\x16.paddle.TableParameter\"\xfd\x01\n\x18\x44ownpourTrainerParameter\x12\x30\n\x0b\x64\x65nse_table\x18\x01 \x03(\x0b\x32\x1b.paddle.DenseTableParameter\x12\x32\n\x0csparse_table\x18\x02 \x03(\x0b\x32\x1c.paddle.SparseTableParameter\x12\x1d\n\x15push_sparse_per_batch\x18\x03 \x01(\x05\x12\x1c\n\x14push_dense_per_batch\x18\x04 \x01(\x05\x12\x0f\n\x07skip_op\x18\x05 \x03(\t\x12-\n\x0eprogram_config\x18\x06 \x03(\x0b\x32\x15.paddle.ProgramConfig\"\x99\x01\n\rProgramConfig\x12\x12\n\nprogram_id\x18\x01 \x02(\t\x12\x1c\n\x14push_sparse_table_id\x18\x02 \x03(\x05\x12\x1b\n\x13push_dense_table_id\x18\x03 \x03(\x05\x12\x1c\n\x14pull_sparse_table_id\x18\x04 \x03(\x05\x12\x1b\n\x13pull_dense_table_id\x18\x05 \x03(\x05\"{\n\x13\x44\x65nseTableParameter\x12\x10\n\x08table_id\x18\x01 \x01(\x05\x12\x1b\n\x13\x64\x65nse_variable_name\x18\x02 \x03(\t\x12$\n\x1c\x64\x65nse_gradient_variable_name\x18\x03 \x03(\t\x12\x0f\n\x07\x66\x65\x61_dim\x18\x04 \x01(\x05\"z\n\x14SparseTableParameter\x12\x10\n\x08table_id\x18\x01 \x01(\x05\x12\x13\n\x0b\x66\x65\x61ture_dim\x18\x02 \x01(\x05\x12\x10\n\x08slot_key\x18\x03 \x03(\t\x12\x12\n\nslot_value\x18\x04 \x03(\t\x12\x15\n\rslot_gradient\x18\x05 \x03(\t\"\x86\x01\n\x17\x44ownpourServerParameter\x12\x34\n\x14\x64ownpour_table_param\x18\x01 \x03(\x0b\x32\x16.paddle.TableParameter\x12\x35\n\rservice_param\x18\x02 \x01(\x0b\x32\x1e.paddle.ServerServiceParameter\"\xd7\x01\n\x16ServerServiceParameter\x12*\n\x0cserver_class\x18\x01 \x01(\t:\x14\x44ownpourBrpcPsServer\x12*\n\x0c\x63lient_class\x18\x02 \x01(\t:\x14\x44ownpourBrpcPsClient\x12(\n\rservice_class\x18\x03 \x01(\t:\x11\x44ownpourPsService\x12\x1c\n\x11start_server_port\x18\x04 \x01(\r:\x01\x30\x12\x1d\n\x11server_thread_num\x18\x05 \x01(\r:\x02\x31\x32\"\xfd\x02\n\x0eTableParameter\x12\x10\n\x08table_id\x18\x01 \x01(\x04\x12\x13\n\x0btable_class\x18\x02 \x01(\t\x12\x17\n\tshard_num\x18\x03 \x01(\x04:\x04\x31\x30\x30\x30\x12\x30\n\x08\x61\x63\x63\x65ssor\x18\x04 \x01(\x0b\x32\x1e.paddle.TableAccessorParameter\x12\x1f\n\x04type\x18\x05 \x01(\x0e\x32\x11.paddle.TableType\x12\x1f\n\x10\x63ompress_in_save\x18\x06 \x01(\x08:\x05\x66\x61lse\x12\'\n\x19\x65nable_sparse_table_cache\x18\x07 \x01(\x08:\x04true\x12(\n\x17sparse_table_cache_rate\x18\x08 \x01(\x01:\x07\x30.00055\x12\'\n\x1bsparse_table_cache_file_num\x18\t \x01(\r:\x02\x31\x36\x12\x1c\n\renable_revert\x18\n \x01(\x08:\x05\x66\x61lse\x12\x1d\n\x10shard_merge_rate\x18\x0b \x01(\x02:\x03\x30.2\"\x86\x05\n\x16TableAccessorParameter\x12\x16\n\x0e\x61\x63\x63\x65ssor_class\x18\x01 \x01(\t\x12\x38\n\x10sparse_sgd_param\x18\x02 \x01(\x0b\x32\x1e.paddle.SparseSGDRuleParameter\x12\x36\n\x0f\x64\x65nse_sgd_param\x18\x03 \x01(\x0b\x32\x1d.paddle.DenseSGDRuleParameter\x12\x13\n\x07\x66\x65\x61_dim\x18\x04 \x01(\r:\x02\x31\x31\x12\x15\n\nembedx_dim\x18\x05 \x01(\r:\x01\x38\x12\x1c\n\x10\x65mbedx_threshold\x18\x06 \x01(\r:\x02\x31\x30\x12G\n\x17\x64ownpour_accessor_param\x18\x07 \x01(\x0b\x32&.paddle.DownpourTableAccessorParameter\x12\x45\n\x19table_accessor_save_param\x18\x08 \x03(\x0b\x32\".paddle.TableAccessorSaveParameter\x12\x44\n\x16sparse_commonsgd_param\x18\t \x01(\x0b\x32$.paddle.SparseCommonSGDRuleParameter\x12=\n\x0f\x65mbed_sgd_param\x18\n \x01(\x0b\x32$.paddle.SparseCommonSGDRuleParameter\x12>\n\x10\x65mbedx_sgd_param\x18\x0b \x01(\x0b\x32$.paddle.SparseCommonSGDRuleParameter\x12\x43\n\x16hit_interval_sgd_param\x18\x0c \x01(\x0b\x32#.paddle.HitIntervalSGDRuleParameter\"\xba\x02\n\x1e\x44ownpourTableAccessorParameter\x12\x19\n\x0cnonclk_coeff\x18\x01 \x01(\x02:\x03\x30.1\x12\x16\n\x0b\x63lick_coeff\x18\x02 \x01(\x02:\x01\x31\x12\x1b\n\x0e\x62\x61se_threshold\x18\x03 \x01(\x02:\x03\x31.5\x12\x1d\n\x0f\x64\x65lta_threshold\x18\x04 \x01(\x02:\x04\x30.25\x12\x1b\n\x0f\x64\x65lta_keep_days\x18\x05 \x01(\x02:\x02\x31\x36\x12#\n\x15show_click_decay_rate\x18\x06 \x01(\x02:\x04\x30.98\x12\x1d\n\x10\x64\x65lete_threshold\x18\x07 \x01(\x02:\x03\x30.8\x12$\n\x18\x64\x65lete_after_unseen_days\x18\x08 \x01(\x02:\x02\x33\x30\x12\"\n\x17ssd_unseenday_threshold\x18\t \x01(\x05:\x01\x31\"S\n\x1aTableAccessorSaveParameter\x12\r\n\x05param\x18\x01 \x01(\r\x12\x11\n\tconverter\x18\x02 \x01(\t\x12\x13\n\x0b\x64\x65\x63onverter\x18\x03 \x01(\t\"e\n\x10PsRequestMessage\x12\x0e\n\x06\x63md_id\x18\x01 \x02(\r\x12\x10\n\x08table_id\x18\x02 \x01(\r\x12\x0e\n\x06params\x18\x03 \x03(\x0c\x12\x11\n\tclient_id\x18\x04 \x01(\x05\x12\x0c\n\x04\x64\x61ta\x18\x05 \x01(\x0c\"\x85\x01\n\x16SparseSGDRuleParameter\x12\x1b\n\rlearning_rate\x18\x01 \x01(\x01:\x04\x30.05\x12\x18\n\rinitial_g2sum\x18\x02 \x01(\x01:\x01\x33\x12\x1d\n\rinitial_range\x18\x03 \x01(\x01:\x06\x30.0001\x12\x15\n\rweight_bounds\x18\x04 \x03(\x02\"T\n\x1bHitIntervalSGDRuleParameter\x12\x1b\n\rlearning_rate\x18\x01 \x01(\x01:\x04\x30.05\x12\x18\n\rinitial_value\x18\x02 \x01(\x01:\x01\x31\"\xc6\x01\n\x1cSparseCommonSGDRuleParameter\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x32\n\x05naive\x18\x02 \x01(\x0b\x32#.paddle.SparseNaiveSGDRuleParameter\x12\x36\n\x07\x61\x64\x61grad\x18\x03 \x01(\x0b\x32%.paddle.SparseAdagradSGDRuleParameter\x12,\n\x04\x61\x64\x61m\x18\x04 \x01(\x0b\x32\x1e.paddle.SparseAdamSGDParameter\"p\n\x1bSparseNaiveSGDRuleParameter\x12\x1b\n\rlearning_rate\x18\x01 \x01(\x01:\x04\x30.05\x12\x1d\n\rinitial_range\x18\x02 \x01(\x01:\x06\x30.0001\x12\x15\n\rweight_bounds\x18\x03 \x03(\x02\"\x8c\x01\n\x1dSparseAdagradSGDRuleParameter\x12\x1b\n\rlearning_rate\x18\x01 \x01(\x01:\x04\x30.05\x12\x18\n\rinitial_g2sum\x18\x02 \x01(\x01:\x01\x33\x12\x1d\n\rinitial_range\x18\x03 \x01(\x01:\x06\x30.0001\x12\x15\n\rweight_bounds\x18\x04 \x03(\x02\"\xc8\x01\n\x16SparseAdamSGDParameter\x12\x1c\n\rlearning_rate\x18\x01 \x01(\x01:\x05\x30.001\x12\x1d\n\rinitial_range\x18\x02 \x01(\x01:\x06\x30.0001\x12\x1d\n\x10\x62\x65ta1_decay_rate\x18\x03 \x01(\x01:\x03\x30.9\x12\x1f\n\x10\x62\x65ta2_decay_rate\x18\x04 \x01(\x01:\x05\x30.999\x12\x1a\n\x0b\x61\x64\x61_epsilon\x18\x05 \x01(\x01:\x05\x31\x65-08\x12\x15\n\rweight_bounds\x18\x06 \x03(\x02\"\xe1\x01\n\x15\x44\x65nseSGDRuleParameter\x12\x0c\n\x04name\x18\x01 \x01(\t\x12&\n\x04\x61\x64\x61m\x18\x02 \x01(\x0b\x32\x18.paddle.AdamSGDParameter\x12(\n\x05naive\x18\x03 \x01(\x0b\x32\x19.paddle.NaiveSGDParameter\x12,\n\x07summary\x18\x04 \x01(\x0b\x32\x1b.paddle.SummarySGDParameter\x12:\n\x0emoving_average\x18\x05 \x01(\x0b\x32\".paddle.MovingAverageRuleParameter\"\xac\x01\n\x10\x41\x64\x61mSGDParameter\x12\x1c\n\rlearning_rate\x18\x01 \x01(\x01:\x05\x35\x65-06\x12 \n\x0e\x61vg_decay_rate\x18\x02 \x01(\x01:\x08\x30.999993\x12\x1e\n\x0e\x61\x64\x61_decay_rate\x18\x03 \x01(\x01:\x06\x30.9999\x12\x1a\n\x0b\x61\x64\x61_epsilon\x18\x04 \x01(\x01:\x05\x31\x65-08\x12\x1c\n\x0emom_decay_rate\x18\x05 \x01(\x01:\x04\x30.99\"J\n\x11NaiveSGDParameter\x12\x1d\n\rlearning_rate\x18\x01 \x01(\x01:\x06\x30.0002\x12\x16\n\x0e\x61vg_decay_rate\x18\x02 \x01(\x01\";\n\x13SummarySGDParameter\x12$\n\x12summary_decay_rate\x18\x01 \x01(\x01:\x08\x30.999999\".\n\x1aMovingAverageRuleParameter\x12\x10\n\x08momentum\x18\x01 \x01(\x01\"I\n\x11PsResponseMessage\x12\x13\n\x08\x65rr_code\x18\x01 \x02(\x05:\x01\x30\x12\x11\n\x07\x65rr_msg\x18\x02 \x02(\t:\x00\x12\x0c\n\x04\x64\x61ta\x18\x03 \x01(\x0c\"\xd5\x01\n\x11\x46sClientParameter\x12:\n\x07\x66s_type\x18\x01 \x01(\x0e\x32#.paddle.FsClientParameter.FsApiType:\x04HDFS\x12\x0b\n\x03uri\x18\x02 \x01(\t\x12\x0c\n\x04user\x18\x03 \x01(\t\x12\x0e\n\x06passwd\x18\x04 \x01(\t\x12\x13\n\x0b\x62uffer_size\x18\x05 \x01(\x05\x12\x12\n\nhadoop_bin\x18\x33 \x01(\t\x12\x10\n\x08\x61\x66s_conf\x18\x65 \x01(\t\"\x1e\n\tFsApiType\x12\x08\n\x04HDFS\x10\x00\x12\x07\n\x03\x41\x46S\x10\x01*4\n\tTableType\x12\x13\n\x0fPS_SPARSE_TABLE\x10\x00\x12\x12\n\x0ePS_DENSE_TABLE\x10\x01*\xc3\x05\n\x07PsCmdID\x12\x17\n\x13PS_PULL_DENSE_TABLE\x10\x00\x12\x17\n\x13PS_PUSH_DENSE_TABLE\x10\x01\x12\x18\n\x14PS_PULL_SPARSE_TABLE\x10\x02\x12\x18\n\x14PS_PUSH_SPARSE_TABLE\x10\x03\x12\x13\n\x0fPS_SHRINK_TABLE\x10\x04\x12\x15\n\x11PS_SAVE_ONE_TABLE\x10\x05\x12\x15\n\x11PS_SAVE_ALL_TABLE\x10\x06\x12\x15\n\x11PS_LOAD_ONE_TABLE\x10\x07\x12\x15\n\x11PS_LOAD_ALL_TABLE\x10\x08\x12\x16\n\x12PS_CLEAR_ONE_TABLE\x10\t\x12\x16\n\x12PS_CLEAR_ALL_TABLE\x10\n\x12\x17\n\x13PS_PUSH_DENSE_PARAM\x10\x0b\x12\x12\n\x0ePS_STOP_SERVER\x10\x0c\x12\x1b\n\x17PS_SAVE_ONE_CACHE_TABLE\x10\r\x12\x1a\n\x16PS_GET_CACHE_THRESHOLD\x10\x0e\x12\x14\n\x10PS_CACHE_SHUFFLE\x10\x0f\x12\x11\n\rPS_COPY_TABLE\x10\x10\x12\x1c\n\x18PS_COPY_TABLE_BY_FEASIGN\x10\x11\x12(\n$PS_PULL_SPARSE_TABLE_WITH_DEPENDENCY\x10\x12\x12(\n$PS_PUSH_SPARSE_TABLE_WITH_DEPENDENCY\x10\x13\x12\x17\n\x13PS_PRINT_TABLE_STAT\x10\x14\x12\x1c\n\x18PS_SAVE_ONE_TABLE_PREFIX\x10\x15\x12$\n PS_SAVE_ONE_TABLE_WITH_WHITELIST\x10\x16\x12$\n PS_LOAD_ONE_TABLE_WITH_WHITELIST\x10\x17\x12\r\n\tPS_REVERT\x10\x18\x12\x0e\n\nPS_CONFIRM\x10\x19\x12\x0e\n\nPS_S2S_MSG\x10\x65\x32K\n\tPsService\x12>\n\x07service\x12\x18.paddle.PsRequestMessage\x1a\x19.paddle.PsResponseMessageB\x06\x80\x01\x01\xf8\x01\x01' )) _sym_db.RegisterFileDescriptor(DESCRIPTOR) @@ -49,8 +49,8 @@ _TABLETYPE = _descriptor.EnumDescriptor( ], containing_type=None, options=None, - serialized_start=4679, - serialized_end=4731, ) + serialized_start=4895, + serialized_end=4947, ) _sym_db.RegisterEnumDescriptor(_TABLETYPE) TableType = enum_type_wrapper.EnumTypeWrapper(_TABLETYPE) @@ -176,12 +176,34 @@ _PSCMDID = _descriptor.EnumDescriptor( options=None, type=None), _descriptor.EnumValueDescriptor( - name='PS_S2S_MSG', index=21, number=101, options=None, type=None), + name='PS_SAVE_ONE_TABLE_PREFIX', + index=21, + number=21, + options=None, + type=None), + _descriptor.EnumValueDescriptor( + name='PS_SAVE_ONE_TABLE_WITH_WHITELIST', + index=22, + number=22, + options=None, + type=None), + _descriptor.EnumValueDescriptor( + name='PS_LOAD_ONE_TABLE_WITH_WHITELIST', + index=23, + number=23, + options=None, + type=None), + _descriptor.EnumValueDescriptor( + name='PS_REVERT', index=24, number=24, options=None, type=None), + _descriptor.EnumValueDescriptor( + name='PS_CONFIRM', index=25, number=25, options=None, type=None), + _descriptor.EnumValueDescriptor( + name='PS_S2S_MSG', index=26, number=101, options=None, type=None), ], containing_type=None, options=None, - serialized_start=4734, - serialized_end=5304, ) + serialized_start=4950, + serialized_end=5657, ) _sym_db.RegisterEnumDescriptor(_PSCMDID) PsCmdID = enum_type_wrapper.EnumTypeWrapper(_PSCMDID) @@ -208,6 +230,11 @@ PS_COPY_TABLE_BY_FEASIGN = 17 PS_PULL_SPARSE_TABLE_WITH_DEPENDENCY = 18 PS_PUSH_SPARSE_TABLE_WITH_DEPENDENCY = 19 PS_PRINT_TABLE_STAT = 20 +PS_SAVE_ONE_TABLE_PREFIX = 21 +PS_SAVE_ONE_TABLE_WITH_WHITELIST = 22 +PS_LOAD_ONE_TABLE_WITH_WHITELIST = 23 +PS_REVERT = 24 +PS_CONFIRM = 25 PS_S2S_MSG = 101 _FSCLIENTPARAMETER_FSAPITYPE = _descriptor.EnumDescriptor( @@ -223,8 +250,8 @@ _FSCLIENTPARAMETER_FSAPITYPE = _descriptor.EnumDescriptor( ], containing_type=None, options=None, - serialized_start=4647, - serialized_end=4677, ) + serialized_start=4863, + serialized_end=4893, ) _sym_db.RegisterEnumDescriptor(_FSCLIENTPARAMETER_FSAPITYPE) _PSPARAMETER = _descriptor.Descriptor( @@ -1176,6 +1203,38 @@ _TABLEPARAMETER = _descriptor.Descriptor( is_extension=False, extension_scope=None, options=None), + _descriptor.FieldDescriptor( + name='enable_revert', + full_name='paddle.TableParameter.enable_revert', + index=9, + number=10, + type=8, + cpp_type=7, + label=1, + has_default_value=True, + default_value=False, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None), + _descriptor.FieldDescriptor( + name='shard_merge_rate', + full_name='paddle.TableParameter.shard_merge_rate', + index=10, + number=11, + type=2, + cpp_type=6, + label=1, + has_default_value=True, + default_value=float(0.2), + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None), ], extensions=[], nested_types=[], @@ -1186,7 +1245,7 @@ _TABLEPARAMETER = _descriptor.Descriptor( extension_ranges=[], oneofs=[], serialized_start=1596, - serialized_end=1916, ) + serialized_end=1977, ) _TABLEACCESSORPARAMETER = _descriptor.Descriptor( name='TableAccessorParameter', @@ -1371,6 +1430,22 @@ _TABLEACCESSORPARAMETER = _descriptor.Descriptor( is_extension=False, extension_scope=None, options=None), + _descriptor.FieldDescriptor( + name='hit_interval_sgd_param', + full_name='paddle.TableAccessorParameter.hit_interval_sgd_param', + index=11, + number=12, + type=11, + cpp_type=10, + label=1, + has_default_value=False, + default_value=None, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None), ], extensions=[], nested_types=[], @@ -1380,8 +1455,8 @@ _TABLEACCESSORPARAMETER = _descriptor.Descriptor( syntax='proto2', extension_ranges=[], oneofs=[], - serialized_start=1919, - serialized_end=2496, ) + serialized_start=1980, + serialized_end=2626, ) _DOWNPOURTABLEACCESSORPARAMETER = _descriptor.Descriptor( name='DownpourTableAccessorParameter', @@ -1543,8 +1618,8 @@ _DOWNPOURTABLEACCESSORPARAMETER = _descriptor.Descriptor( syntax='proto2', extension_ranges=[], oneofs=[], - serialized_start=2499, - serialized_end=2813, ) + serialized_start=2629, + serialized_end=2943, ) _TABLEACCESSORSAVEPARAMETER = _descriptor.Descriptor( name='TableAccessorSaveParameter', @@ -1610,8 +1685,8 @@ _TABLEACCESSORSAVEPARAMETER = _descriptor.Descriptor( syntax='proto2', extension_ranges=[], oneofs=[], - serialized_start=2815, - serialized_end=2898, ) + serialized_start=2945, + serialized_end=3028, ) _PSREQUESTMESSAGE = _descriptor.Descriptor( name='PsRequestMessage', @@ -1709,8 +1784,8 @@ _PSREQUESTMESSAGE = _descriptor.Descriptor( syntax='proto2', extension_ranges=[], oneofs=[], - serialized_start=2900, - serialized_end=3001, ) + serialized_start=3030, + serialized_end=3131, ) _SPARSESGDRULEPARAMETER = _descriptor.Descriptor( name='SparseSGDRuleParameter', @@ -1792,8 +1867,59 @@ _SPARSESGDRULEPARAMETER = _descriptor.Descriptor( syntax='proto2', extension_ranges=[], oneofs=[], - serialized_start=3004, - serialized_end=3137, ) + serialized_start=3134, + serialized_end=3267, ) + +_HITINTERVALSGDRULEPARAMETER = _descriptor.Descriptor( + name='HitIntervalSGDRuleParameter', + full_name='paddle.HitIntervalSGDRuleParameter', + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name='learning_rate', + full_name='paddle.HitIntervalSGDRuleParameter.learning_rate', + index=0, + number=1, + type=1, + cpp_type=5, + label=1, + has_default_value=True, + default_value=float(0.05), + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None), + _descriptor.FieldDescriptor( + name='initial_value', + full_name='paddle.HitIntervalSGDRuleParameter.initial_value', + index=1, + number=2, + type=1, + cpp_type=5, + label=1, + has_default_value=True, + default_value=float(1), + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None), + ], + extensions=[], + nested_types=[], + enum_types=[], + options=None, + is_extendable=False, + syntax='proto2', + extension_ranges=[], + oneofs=[], + serialized_start=3269, + serialized_end=3353, ) _SPARSECOMMONSGDRULEPARAMETER = _descriptor.Descriptor( name='SparseCommonSGDRuleParameter', @@ -1875,8 +2001,8 @@ _SPARSECOMMONSGDRULEPARAMETER = _descriptor.Descriptor( syntax='proto2', extension_ranges=[], oneofs=[], - serialized_start=3140, - serialized_end=3338, ) + serialized_start=3356, + serialized_end=3554, ) _SPARSENAIVESGDRULEPARAMETER = _descriptor.Descriptor( name='SparseNaiveSGDRuleParameter', @@ -1942,8 +2068,8 @@ _SPARSENAIVESGDRULEPARAMETER = _descriptor.Descriptor( syntax='proto2', extension_ranges=[], oneofs=[], - serialized_start=3340, - serialized_end=3452, ) + serialized_start=3556, + serialized_end=3668, ) _SPARSEADAGRADSGDRULEPARAMETER = _descriptor.Descriptor( name='SparseAdagradSGDRuleParameter', @@ -2025,8 +2151,8 @@ _SPARSEADAGRADSGDRULEPARAMETER = _descriptor.Descriptor( syntax='proto2', extension_ranges=[], oneofs=[], - serialized_start=3455, - serialized_end=3595, ) + serialized_start=3671, + serialized_end=3811, ) _SPARSEADAMSGDPARAMETER = _descriptor.Descriptor( name='SparseAdamSGDParameter', @@ -2140,8 +2266,8 @@ _SPARSEADAMSGDPARAMETER = _descriptor.Descriptor( syntax='proto2', extension_ranges=[], oneofs=[], - serialized_start=3598, - serialized_end=3798, ) + serialized_start=3814, + serialized_end=4014, ) _DENSESGDRULEPARAMETER = _descriptor.Descriptor( name='DenseSGDRuleParameter', @@ -2239,8 +2365,8 @@ _DENSESGDRULEPARAMETER = _descriptor.Descriptor( syntax='proto2', extension_ranges=[], oneofs=[], - serialized_start=3801, - serialized_end=4026, ) + serialized_start=4017, + serialized_end=4242, ) _ADAMSGDPARAMETER = _descriptor.Descriptor( name='AdamSGDParameter', @@ -2338,8 +2464,8 @@ _ADAMSGDPARAMETER = _descriptor.Descriptor( syntax='proto2', extension_ranges=[], oneofs=[], - serialized_start=4029, - serialized_end=4201, ) + serialized_start=4245, + serialized_end=4417, ) _NAIVESGDPARAMETER = _descriptor.Descriptor( name='NaiveSGDParameter', @@ -2389,8 +2515,8 @@ _NAIVESGDPARAMETER = _descriptor.Descriptor( syntax='proto2', extension_ranges=[], oneofs=[], - serialized_start=4203, - serialized_end=4277, ) + serialized_start=4419, + serialized_end=4493, ) _SUMMARYSGDPARAMETER = _descriptor.Descriptor( name='SummarySGDParameter', @@ -2424,8 +2550,8 @@ _SUMMARYSGDPARAMETER = _descriptor.Descriptor( syntax='proto2', extension_ranges=[], oneofs=[], - serialized_start=4279, - serialized_end=4338, ) + serialized_start=4495, + serialized_end=4554, ) _MOVINGAVERAGERULEPARAMETER = _descriptor.Descriptor( name='MovingAverageRuleParameter', @@ -2459,8 +2585,8 @@ _MOVINGAVERAGERULEPARAMETER = _descriptor.Descriptor( syntax='proto2', extension_ranges=[], oneofs=[], - serialized_start=4340, - serialized_end=4386, ) + serialized_start=4556, + serialized_end=4602, ) _PSRESPONSEMESSAGE = _descriptor.Descriptor( name='PsResponseMessage', @@ -2526,8 +2652,8 @@ _PSRESPONSEMESSAGE = _descriptor.Descriptor( syntax='proto2', extension_ranges=[], oneofs=[], - serialized_start=4388, - serialized_end=4461, ) + serialized_start=4604, + serialized_end=4677, ) _FSCLIENTPARAMETER = _descriptor.Descriptor( name='FsClientParameter', @@ -2657,8 +2783,8 @@ _FSCLIENTPARAMETER = _descriptor.Descriptor( syntax='proto2', extension_ranges=[], oneofs=[], - serialized_start=4464, - serialized_end=4677, ) + serialized_start=4680, + serialized_end=4893, ) _PSPARAMETER.fields_by_name['worker_param'].message_type = _WORKERPARAMETER _PSPARAMETER.fields_by_name['server_param'].message_type = _SERVERPARAMETER @@ -2698,6 +2824,8 @@ _TABLEACCESSORPARAMETER.fields_by_name[ 'embed_sgd_param'].message_type = _SPARSECOMMONSGDRULEPARAMETER _TABLEACCESSORPARAMETER.fields_by_name[ 'embedx_sgd_param'].message_type = _SPARSECOMMONSGDRULEPARAMETER +_TABLEACCESSORPARAMETER.fields_by_name[ + 'hit_interval_sgd_param'].message_type = _HITINTERVALSGDRULEPARAMETER _SPARSECOMMONSGDRULEPARAMETER.fields_by_name[ 'naive'].message_type = _SPARSENAIVESGDRULEPARAMETER _SPARSECOMMONSGDRULEPARAMETER.fields_by_name[ @@ -2737,6 +2865,8 @@ DESCRIPTOR.message_types_by_name[ DESCRIPTOR.message_types_by_name['PsRequestMessage'] = _PSREQUESTMESSAGE DESCRIPTOR.message_types_by_name[ 'SparseSGDRuleParameter'] = _SPARSESGDRULEPARAMETER +DESCRIPTOR.message_types_by_name[ + 'HitIntervalSGDRuleParameter'] = _HITINTERVALSGDRULEPARAMETER DESCRIPTOR.message_types_by_name[ 'SparseCommonSGDRuleParameter'] = _SPARSECOMMONSGDRULEPARAMETER DESCRIPTOR.message_types_by_name[ @@ -2917,6 +3047,16 @@ SparseSGDRuleParameter = _reflection.GeneratedProtocolMessageType( )) _sym_db.RegisterMessage(SparseSGDRuleParameter) +HitIntervalSGDRuleParameter = _reflection.GeneratedProtocolMessageType( + 'HitIntervalSGDRuleParameter', + (_message.Message, ), + dict( + DESCRIPTOR=_HITINTERVALSGDRULEPARAMETER, + __module__='ps_pb2' + # @@protoc_insertion_point(class_scope:paddle.HitIntervalSGDRuleParameter) + )) +_sym_db.RegisterMessage(HitIntervalSGDRuleParameter) + SparseCommonSGDRuleParameter = _reflection.GeneratedProtocolMessageType( 'SparseCommonSGDRuleParameter', (_message.Message, ), diff --git a/python/paddle/fluid/trainer_desc.py b/python/paddle/fluid/trainer_desc.py index 6c4621ac39b..6c2065f8145 100644 --- a/python/paddle/fluid/trainer_desc.py +++ b/python/paddle/fluid/trainer_desc.py @@ -131,6 +131,9 @@ class TrainerDesc(object): for loss in loss_names: self.proto_desc.loss_names.append(loss) + def _set_dump_prob(self, dump_prob): + self.proto_desc.dump_prob = dump_prob + def _set_adjust_ins_weight(self, config_dict): self.proto_desc.adjust_ins_weight_config.need_adjust = \ config_dict.get("need_adjust", False) diff --git a/python/paddle/fluid/trainer_factory.py b/python/paddle/fluid/trainer_factory.py index ec6d10a37b8..bcae3ea039d 100644 --- a/python/paddle/fluid/trainer_factory.py +++ b/python/paddle/fluid/trainer_factory.py @@ -68,10 +68,15 @@ class TrainerFactory(object): trainer._set_dump_fields_path(opt_info["dump_fields_path"]) if opt_info.get("dump_file_num") is not None: trainer._set_dump_file_num(opt_info["dump_file_num"]) + if opt_info.get("user_define_dump_filename") is not None: + trainer._set_user_define_dump_filename(opt_info[ + "user_define_dump_filename"]) if opt_info.get("dump_converter") is not None: trainer._set_dump_converter(opt_info["dump_converter"]) if opt_info.get("dump_param") is not None: trainer._set_dump_param(opt_info["dump_param"]) + if opt_info.get("dump_prob") is not None: + trainer._set_dump_prob(opt_info["dump_prob"]) if "fleet_desc" in opt_info: device_worker._set_fleet_desc(opt_info["fleet_desc"]) -- GitLab