From 349e82d66936d007c616e4cba1394b92a29023b9 Mon Sep 17 00:00:00 2001 From: Thunderbrook <52529258+Thunderbrook@users.noreply.github.com> Date: Wed, 20 Nov 2019 15:20:49 +0800 Subject: [PATCH] support general embedding params (#21217) * general table * add sparse table test=develop * no cvm test=develop * add no_cvm test=develop * add note test=develop * code style test=develop * code style test=develop * code style test=develop * code style test=develop * code style test=develop * add key of optimizer test=develop --- paddle/fluid/framework/device_worker.h | 6 +- paddle/fluid/framework/downpour_worker.cc | 11 +- paddle/fluid/framework/fleet/fleet_wrapper.cc | 10 +- paddle/fluid/framework/fleet/fleet_wrapper.h | 2 +- paddle/fluid/framework/trainer_desc.proto | 1 + python/paddle/fluid/device_worker.py | 3 +- .../fleet/parameter_server/pslib/node.py | 72 ++- .../pslib/optimizer_factory.py | 1 + .../fleet/parameter_server/pslib/ps_pb2.py | 506 ++++++++++++++++-- python/paddle/fluid/trainer_desc.py | 3 + python/paddle/fluid/trainer_factory.py | 2 + 11 files changed, 575 insertions(+), 42 deletions(-) diff --git a/paddle/fluid/framework/device_worker.h b/paddle/fluid/framework/device_worker.h index 737df3dd6b..2ee09a7b83 100644 --- a/paddle/fluid/framework/device_worker.h +++ b/paddle/fluid/framework/device_worker.h @@ -105,7 +105,10 @@ class PullDenseWorker { // should incorporate different type of device class DeviceWorker { public: - DeviceWorker() { use_cvm_ = false; } + DeviceWorker() { + no_cvm_ = true; + use_cvm_ = false; + } virtual ~DeviceWorker() {} virtual void Initialize(const TrainerDesc& desc) = 0; virtual void SetDeviceIndex(int tid) = 0; @@ -135,6 +138,7 @@ class DeviceWorker { int64_t batch_num_; FetchConfig fetch_config_; bool use_cvm_; + bool no_cvm_; }; class CPUWorkerBase : public DeviceWorker { diff --git a/paddle/fluid/framework/downpour_worker.cc b/paddle/fluid/framework/downpour_worker.cc index cedf22bd9f..60dbe3c2f9 100644 --- a/paddle/fluid/framework/downpour_worker.cc +++ b/paddle/fluid/framework/downpour_worker.cc @@ -75,6 +75,8 @@ void DownpourWorker::Initialize(const TrainerDesc& desc) { fleet_ptr_ = FleetWrapper::GetInstance(); fetch_config_ = desc.fetch_config(); use_cvm_ = desc.use_cvm(); + // for sparse value accessor, embedding only + no_cvm_ = desc.no_cvm(); scale_datanorm_ = desc.scale_datanorm(); dump_slot_ = desc.dump_slot(); dump_fields_.resize(desc.dump_fields_size()); @@ -211,6 +213,9 @@ void DownpourWorker::DumpParam() { } void DownpourWorker::CollectLabelInfo(size_t table_idx) { + if (no_cvm_) { + return; + } uint64_t table_id = static_cast( param_.program_config(0).pull_sparse_table_id(table_idx)); @@ -312,7 +317,7 @@ void DownpourWorker::FillSparseValue(size_t table_idx) { int nid_ins_index = 0; for (int index = 0; index < len; ++index) { - if (use_cvm_) { + if (use_cvm_ || no_cvm_) { if (ids[index] == 0u) { memcpy(ptr + table.emb_dim() * index, init_value.data(), sizeof(float) * table.emb_dim()); @@ -681,7 +686,7 @@ void DownpourWorker::TrainFilesWithProfiler() { *thread_scope_, tid, features_[tid], feature_labels_[tid], sparse_key_names_[tid], sparse_grad_names_[tid], table.emb_dim(), &feature_grads_[tid], &push_sparse_status_, cur_batch, use_cvm_, - dump_slot_, &sparse_push_keys_[tid]); + dump_slot_, &sparse_push_keys_[tid], no_cvm_); timeline.Pause(); push_sparse_time += timeline.ElapsedSec(); total_time += timeline.ElapsedSec(); @@ -906,7 +911,7 @@ void DownpourWorker::TrainFiles() { *thread_scope_, tid, features_[tid], feature_labels_[tid], sparse_key_names_[tid], sparse_grad_names_[tid], table.emb_dim(), &feature_grads_[tid], &push_sparse_status_, cur_batch, use_cvm_, - dump_slot_, &sparse_push_keys_[tid]); + dump_slot_, &sparse_push_keys_[tid], no_cvm_); } } diff --git a/paddle/fluid/framework/fleet/fleet_wrapper.cc b/paddle/fluid/framework/fleet/fleet_wrapper.cc index bef9c85e6d..d46a7ec1fc 100644 --- a/paddle/fluid/framework/fleet/fleet_wrapper.cc +++ b/paddle/fluid/framework/fleet/fleet_wrapper.cc @@ -303,7 +303,7 @@ void FleetWrapper::PushSparseVarsWithLabelAsync( std::vector>* push_values, std::vector<::std::future>* push_sparse_status, const int batch_size, const bool use_cvm, const bool dump_slot, - std::vector* sparse_push_keys) { + std::vector* sparse_push_keys, const bool no_cvm) { #ifdef PADDLE_WITH_PSLIB int offset = 2; int slot_offset = 0; @@ -314,6 +314,10 @@ void FleetWrapper::PushSparseVarsWithLabelAsync( offset = 0; grad_dim = emb_dim - 2; } + if (no_cvm) { + offset = 0; + grad_dim = emb_dim; + } if (dump_slot) { slot_offset = 1; show_index = 1; @@ -370,12 +374,12 @@ void FleetWrapper::PushSparseVarsWithLabelAsync( } sparse_push_keys->push_back(ids[id_idx]); CHECK(fea_idx < (*push_values).size()); - CHECK(fea_idx < fea_labels.size()); - if (use_cvm) { + if (use_cvm || no_cvm) { memcpy((*push_values)[fea_idx].data() + offset + slot_offset, g, sizeof(float) * emb_dim); } else { + CHECK(fea_idx < fea_labels.size()); memcpy((*push_values)[fea_idx].data() + offset + slot_offset, g, sizeof(float) * emb_dim); (*push_values)[fea_idx][show_index] = 1.0f; diff --git a/paddle/fluid/framework/fleet/fleet_wrapper.h b/paddle/fluid/framework/fleet/fleet_wrapper.h index 73247748b0..fc98cba853 100644 --- a/paddle/fluid/framework/fleet/fleet_wrapper.h +++ b/paddle/fluid/framework/fleet/fleet_wrapper.h @@ -124,7 +124,7 @@ class FleetWrapper { std::vector>* push_values, std::vector<::std::future>* push_sparse_status, const int batch_size, const bool use_cvm, const bool dump_slot, - std::vector* sparse_push_keys); + std::vector* sparse_push_keys, const bool no_cvm); // Push sparse variables to server in Async mode // Param: scope, table_id, fea_keys, sparse_grad_names diff --git a/paddle/fluid/framework/trainer_desc.proto b/paddle/fluid/framework/trainer_desc.proto index 5212c09b65..2d42220fd4 100644 --- a/paddle/fluid/framework/trainer_desc.proto +++ b/paddle/fluid/framework/trainer_desc.proto @@ -46,6 +46,7 @@ message TrainerDesc { optional CopyTableConfig copy_table_config = 19; // adjust ins weight optional AdjustInsWeightConfig adjust_ins_weight_config = 20; + optional bool no_cvm = 21 [ default = false ]; // device worker parameters optional HogwildWorkerParameter hogwild_param = 101; diff --git a/python/paddle/fluid/device_worker.py b/python/paddle/fluid/device_worker.py index db2c15a1d6..2700f006db 100644 --- a/python/paddle/fluid/device_worker.py +++ b/python/paddle/fluid/device_worker.py @@ -160,7 +160,8 @@ class DownpourSGD(DeviceWorker): .sparse_table[i].slot_value) sparse_table.sparse_grad_name.extend(worker.get_desc().sparse_table[ i].slot_gradient) - if opt_info["use_cvm"]: + if opt_info["use_cvm"] or "no_cvm" in opt_info and opt_info[ + "no_cvm"] == True: sparse_table.emb_dim = \ self._fleet_desc.server_param.downpour_server_param.downpour_table_param[ i].accessor.fea_dim diff --git a/python/paddle/fluid/incubate/fleet/parameter_server/pslib/node.py b/python/paddle/fluid/incubate/fleet/parameter_server/pslib/node.py index a283c01853..5afcf0cf2e 100644 --- a/python/paddle/fluid/incubate/fleet/parameter_server/pslib/node.py +++ b/python/paddle/fluid/incubate/fleet/parameter_server/pslib/node.py @@ -80,7 +80,8 @@ class DownpourServer(Server): 'sparse_click_coeff', 'sparse_base_threshold', 'sparse_delta_threshold', 'sparse_delta_keep_days', \ 'sparse_delete_after_unseen_days', 'sparse_show_click_decay_rate', 'sparse_delete_threshold', \ 'sparse_converter', 'sparse_deconverter', 'sparse_enable_cache', 'sparse_cache_rate', \ - 'sparse_cache_file_num'] + 'sparse_cache_file_num', 'sparse_beta1_decay_rate', 'sparse_beta2_decay_rate', \ + 'sparse_ada_epsilon', 'sparse_optimizer'] for key in strategy: if key not in support_sparse_key_list: @@ -108,9 +109,13 @@ class DownpourServer(Server): table.compress_in_save = strategy.get('sparse_compress_in_save', True) table.shard_num = strategy.get('sparse_shard_num', 1000) + # DownpourFeatureValueAccessor: for ctr task, has cvm, embedding and sgd info + # DownpourCtrAccessor : for ctr task, has cvm, slot, embedding and sgd info + # DownpourSparseValueAccessor : for general task, has embedding and sgd info support_accessor_class = [ - 'DownpourFeatureValueAccessor', 'DownpourCtrAccessor' + 'DownpourFeatureValueAccessor', 'DownpourCtrAccessor', + 'DownpourSparseValueAccessor' ] if strategy.get('sparse_accessor_class') is not None: accessor_class = strategy.get('sparse_accessor_class') @@ -169,6 +174,69 @@ class DownpourServer(Server): table1.converter = converter table1.deconverter = deconverter + table2 = table.accessor.table_accessor_save_param.add() + table2.param = 2 + table2.converter = converter + table2.deconverter = deconverter + elif accessor_class == 'DownpourSparseValueAccessor': + optimizer_name = strategy.get("sparse_optimizer", "adam") + table.accessor.sparse_commonsgd_param.name = optimizer_name + table.accessor.embedx_dim = strategy.get('sparse_embedx_dim', 8) + table.accessor.fea_dim = int(table.accessor.embedx_dim) + if optimizer_name == "naive": + table.accessor.sparse_commonsgd_param.naive.learning_rate = \ + strategy.get('sparse_learning_rate', 0.05) + table.accessor.sparse_commonsgd_param.naive.initial_range = \ + strategy.get('sparse_initial_range', 1e-4) + if strategy.get('sparse_weight_bounds') is None: + table.accessor.sparse_commonsgd_param.naive.weight_bounds.extend( + [-10, 10]) + else: + table.accessor.sparse_commonsgd_param.naive.weight_bounds.extend( + strategy.get('sparse_weight_bounds')) + elif optimizer_name == "adagrad": + table.accessor.sparse_commonsgd_param.adagrad.learning_rate = \ + strategy.get('sparse_learning_rate', 0.05) + table.accessor.sparse_commonsgd_param.adagrad.initial_range = \ + strategy.get('sparse_initial_range', 1e-4) + table.accessor.sparse_commonsgd_param.adagrad.initial_g2sum = strategy.get( + 'sparse_initial_g2sum', 3) + if strategy.get('sparse_weight_bounds') is None: + table.accessor.sparse_commonsgd_param.adagrad.weight_bounds.extend( + [-10, 10]) + else: + table.accessor.sparse_commonsgd_param.adagrad.weight_bounds.extend( + strategy.get('sparse_weight_bounds')) + elif optimizer_name == "adam": + table.accessor.sparse_commonsgd_param.adam.learning_rate = \ + strategy.get('sparse_learning_rate', 0.001) + table.accessor.sparse_commonsgd_param.adam.initial_range = \ + strategy.get('sparse_initial_range', 1e-4) + table.accessor.sparse_commonsgd_param.adam.beta1_decay_rate = strategy.get( + 'sparse_beta1_decay_rate', 0.9) + table.accessor.sparse_commonsgd_param.adam.beta2_decay_rate = strategy.get( + 'sparse_beta2_decay_rate', 0.999) + table.accessor.sparse_commonsgd_param.adam.ada_epsilon = strategy.get( + 'sparse_ada_epsilon', 1e-8) + if strategy.get('sparse_weight_bounds') is None: + table.accessor.sparse_commonsgd_param.adam.weight_bounds.extend( + [-10, 10]) + else: + table.accessor.sparse_commonsgd_param.adam.weight_bounds.extend( + strategy.get('sparse_weight_bounds')) + converter = strategy.get( + 'sparse_converter', + "(scripts/xbox_compressor_mf.py | bin/xbox_pb_converter)") + deconverter = strategy.get( + 'sparse_deconverter', + "(bin/xbox_pb_deconverter | scripts/xbox_decompressor_mf.awk)" + ) + + table1 = table.accessor.table_accessor_save_param.add() + table1.param = 1 + table1.converter = converter + table1.deconverter = deconverter + table2 = table.accessor.table_accessor_save_param.add() table2.param = 2 table2.converter = converter diff --git a/python/paddle/fluid/incubate/fleet/parameter_server/pslib/optimizer_factory.py b/python/paddle/fluid/incubate/fleet/parameter_server/pslib/optimizer_factory.py index 89225c37fe..b16d3a71df 100644 --- a/python/paddle/fluid/incubate/fleet/parameter_server/pslib/optimizer_factory.py +++ b/python/paddle/fluid/incubate/fleet/parameter_server/pslib/optimizer_factory.py @@ -367,6 +367,7 @@ class DistributedAdam(DistributedOptimizerImplBase): opt_info["fleet_desc"] = ps_param opt_info["worker_skipped_ops"] = worker_skipped_ops opt_info["use_cvm"] = strategy.get("use_cvm", False) + opt_info["no_cvm"] = strategy.get("no_cvm", False) opt_info["stat_var_names"] = strategy.get("stat_var_names", []) opt_info["scale_datanorm"] = strategy.get("scale_datanorm", -1) opt_info["check_nan_var_names"] = strategy.get("check_nan_var_names", diff --git a/python/paddle/fluid/incubate/fleet/parameter_server/pslib/ps_pb2.py b/python/paddle/fluid/incubate/fleet/parameter_server/pslib/ps_pb2.py index 0021b61094..eec3b0716d 100644 --- a/python/paddle/fluid/incubate/fleet/parameter_server/pslib/ps_pb2.py +++ b/python/paddle/fluid/incubate/fleet/parameter_server/pslib/ps_pb2.py @@ -32,7 +32,7 @@ DESCRIPTOR = _descriptor.FileDescriptor( package='paddle', syntax='proto2', serialized_pb=_b( - '\n\x08ps.proto\x12\x06paddle\"\x9e\x02\n\x0bPSParameter\x12\x14\n\x0cworker_class\x18\x01 \x01(\t\x12\x14\n\x0cserver_class\x18\x02 \x01(\t\x12\x16\n\x0einstance_class\x18\x03 \x01(\t\x12-\n\x0cworker_param\x18\x65 \x01(\x0b\x32\x17.paddle.WorkerParameter\x12-\n\x0cserver_param\x18\x66 \x01(\x0b\x32\x17.paddle.ServerParameter\x12\x38\n\rtrainer_param\x18\xad\x02 \x03(\x0b\x32 .paddle.DownpourTrainerParameter\x12\x33\n\x0f\x66s_client_param\x18\xf5\x03 \x01(\x0b\x32\x19.paddle.FsClientParameter\"Q\n\x0fWorkerParameter\x12>\n\x15\x64ownpour_worker_param\x18\x01 \x01(\x0b\x32\x1f.paddle.DownpourWorkerParameter\"Q\n\x0fServerParameter\x12>\n\x15\x64ownpour_server_param\x18\x01 \x01(\x0b\x32\x1f.paddle.DownpourServerParameter\"O\n\x17\x44ownpourWorkerParameter\x12\x34\n\x14\x64ownpour_table_param\x18\x01 \x03(\x0b\x32\x16.paddle.TableParameter\"\xfd\x01\n\x18\x44ownpourTrainerParameter\x12\x30\n\x0b\x64\x65nse_table\x18\x01 \x03(\x0b\x32\x1b.paddle.DenseTableParameter\x12\x32\n\x0csparse_table\x18\x02 \x03(\x0b\x32\x1c.paddle.SparseTableParameter\x12\x1d\n\x15push_sparse_per_batch\x18\x03 \x01(\x05\x12\x1c\n\x14push_dense_per_batch\x18\x04 \x01(\x05\x12\x0f\n\x07skip_op\x18\x05 \x03(\t\x12-\n\x0eprogram_config\x18\x06 \x03(\x0b\x32\x15.paddle.ProgramConfig\"\x99\x01\n\rProgramConfig\x12\x12\n\nprogram_id\x18\x01 \x02(\t\x12\x1c\n\x14push_sparse_table_id\x18\x02 \x03(\x05\x12\x1b\n\x13push_dense_table_id\x18\x03 \x03(\x05\x12\x1c\n\x14pull_sparse_table_id\x18\x04 \x03(\x05\x12\x1b\n\x13pull_dense_table_id\x18\x05 \x03(\x05\"{\n\x13\x44\x65nseTableParameter\x12\x10\n\x08table_id\x18\x01 \x01(\x05\x12\x1b\n\x13\x64\x65nse_variable_name\x18\x02 \x03(\t\x12$\n\x1c\x64\x65nse_gradient_variable_name\x18\x03 \x03(\t\x12\x0f\n\x07\x66\x65\x61_dim\x18\x04 \x01(\x05\"z\n\x14SparseTableParameter\x12\x10\n\x08table_id\x18\x01 \x01(\x05\x12\x13\n\x0b\x66\x65\x61ture_dim\x18\x02 \x01(\x05\x12\x10\n\x08slot_key\x18\x03 \x03(\t\x12\x12\n\nslot_value\x18\x04 \x03(\t\x12\x15\n\rslot_gradient\x18\x05 \x03(\t\"\x86\x01\n\x17\x44ownpourServerParameter\x12\x34\n\x14\x64ownpour_table_param\x18\x01 \x03(\x0b\x32\x16.paddle.TableParameter\x12\x35\n\rservice_param\x18\x02 \x01(\x0b\x32\x1e.paddle.ServerServiceParameter\"\xd7\x01\n\x16ServerServiceParameter\x12*\n\x0cserver_class\x18\x01 \x01(\t:\x14\x44ownpourBrpcPsServer\x12*\n\x0c\x63lient_class\x18\x02 \x01(\t:\x14\x44ownpourBrpcPsClient\x12(\n\rservice_class\x18\x03 \x01(\t:\x11\x44ownpourPsService\x12\x1c\n\x11start_server_port\x18\x04 \x01(\r:\x01\x30\x12\x1d\n\x11server_thread_num\x18\x05 \x01(\r:\x02\x31\x32\"\xc0\x02\n\x0eTableParameter\x12\x10\n\x08table_id\x18\x01 \x01(\x04\x12\x13\n\x0btable_class\x18\x02 \x01(\t\x12\x17\n\tshard_num\x18\x03 \x01(\x04:\x04\x31\x30\x30\x30\x12\x30\n\x08\x61\x63\x63\x65ssor\x18\x04 \x01(\x0b\x32\x1e.paddle.TableAccessorParameter\x12\x1f\n\x04type\x18\x05 \x01(\x0e\x32\x11.paddle.TableType\x12\x1f\n\x10\x63ompress_in_save\x18\x06 \x01(\x08:\x05\x66\x61lse\x12\'\n\x19\x65nable_sparse_table_cache\x18\x07 \x01(\x08:\x04true\x12(\n\x17sparse_table_cache_rate\x18\x08 \x01(\x01:\x07\x30.00055\x12\'\n\x1bsparse_table_cache_file_num\x18\t \x01(\r:\x02\x31\x36\"\xfc\x02\n\x16TableAccessorParameter\x12\x16\n\x0e\x61\x63\x63\x65ssor_class\x18\x01 \x01(\t\x12\x38\n\x10sparse_sgd_param\x18\x02 \x01(\x0b\x32\x1e.paddle.SparseSGDRuleParameter\x12\x36\n\x0f\x64\x65nse_sgd_param\x18\x03 \x01(\x0b\x32\x1d.paddle.DenseSGDRuleParameter\x12\x13\n\x07\x66\x65\x61_dim\x18\x04 \x01(\r:\x02\x31\x31\x12\x15\n\nembedx_dim\x18\x05 \x01(\r:\x01\x38\x12\x1c\n\x10\x65mbedx_threshold\x18\x06 \x01(\r:\x02\x31\x30\x12G\n\x17\x64ownpour_accessor_param\x18\x07 \x01(\x0b\x32&.paddle.DownpourTableAccessorParameter\x12\x45\n\x19table_accessor_save_param\x18\x08 \x03(\x0b\x32\".paddle.TableAccessorSaveParameter\"\x96\x02\n\x1e\x44ownpourTableAccessorParameter\x12\x19\n\x0cnonclk_coeff\x18\x01 \x01(\x02:\x03\x30.1\x12\x16\n\x0b\x63lick_coeff\x18\x02 \x01(\x02:\x01\x31\x12\x1b\n\x0e\x62\x61se_threshold\x18\x03 \x01(\x02:\x03\x31.5\x12\x1d\n\x0f\x64\x65lta_threshold\x18\x04 \x01(\x02:\x04\x30.25\x12\x1b\n\x0f\x64\x65lta_keep_days\x18\x05 \x01(\x02:\x02\x31\x36\x12#\n\x15show_click_decay_rate\x18\x06 \x01(\x02:\x04\x30.98\x12\x1d\n\x10\x64\x65lete_threshold\x18\x07 \x01(\x02:\x03\x30.8\x12$\n\x18\x64\x65lete_after_unseen_days\x18\x08 \x01(\x02:\x02\x33\x30\"S\n\x1aTableAccessorSaveParameter\x12\r\n\x05param\x18\x01 \x01(\r\x12\x11\n\tconverter\x18\x02 \x01(\t\x12\x13\n\x0b\x64\x65\x63onverter\x18\x03 \x01(\t\"e\n\x10PsRequestMessage\x12\x0e\n\x06\x63md_id\x18\x01 \x02(\r\x12\x10\n\x08table_id\x18\x02 \x01(\r\x12\x0e\n\x06params\x18\x03 \x03(\x0c\x12\x11\n\tclient_id\x18\x04 \x01(\x05\x12\x0c\n\x04\x64\x61ta\x18\x05 \x01(\x0c\"\x85\x01\n\x16SparseSGDRuleParameter\x12\x1b\n\rlearning_rate\x18\x01 \x01(\x01:\x04\x30.05\x12\x18\n\rinitial_g2sum\x18\x02 \x01(\x01:\x01\x33\x12\x1d\n\rinitial_range\x18\x03 \x01(\x01:\x06\x30.0001\x12\x15\n\rweight_bounds\x18\x04 \x03(\x02\"\xe1\x01\n\x15\x44\x65nseSGDRuleParameter\x12\x0c\n\x04name\x18\x01 \x01(\t\x12&\n\x04\x61\x64\x61m\x18\x02 \x01(\x0b\x32\x18.paddle.AdamSGDParameter\x12(\n\x05naive\x18\x03 \x01(\x0b\x32\x19.paddle.NaiveSGDParameter\x12,\n\x07summary\x18\x04 \x01(\x0b\x32\x1b.paddle.SummarySGDParameter\x12:\n\x0emoving_average\x18\x05 \x01(\x0b\x32\".paddle.MovingAverageRuleParameter\"\xac\x01\n\x10\x41\x64\x61mSGDParameter\x12\x1c\n\rlearning_rate\x18\x01 \x01(\x01:\x05\x35\x65-06\x12 \n\x0e\x61vg_decay_rate\x18\x02 \x01(\x01:\x08\x30.999993\x12\x1e\n\x0e\x61\x64\x61_decay_rate\x18\x03 \x01(\x01:\x06\x30.9999\x12\x1a\n\x0b\x61\x64\x61_epsilon\x18\x04 \x01(\x01:\x05\x31\x65-08\x12\x1c\n\x0emom_decay_rate\x18\x05 \x01(\x01:\x04\x30.99\"J\n\x11NaiveSGDParameter\x12\x1d\n\rlearning_rate\x18\x01 \x01(\x01:\x06\x30.0002\x12\x16\n\x0e\x61vg_decay_rate\x18\x02 \x01(\x01\";\n\x13SummarySGDParameter\x12$\n\x12summary_decay_rate\x18\x01 \x01(\x01:\x08\x30.999999\".\n\x1aMovingAverageRuleParameter\x12\x10\n\x08momentum\x18\x01 \x01(\x01\"I\n\x11PsResponseMessage\x12\x13\n\x08\x65rr_code\x18\x01 \x02(\x05:\x01\x30\x12\x11\n\x07\x65rr_msg\x18\x02 \x02(\t:\x00\x12\x0c\n\x04\x64\x61ta\x18\x03 \x01(\x0c\"\xd5\x01\n\x11\x46sClientParameter\x12:\n\x07\x66s_type\x18\x01 \x01(\x0e\x32#.paddle.FsClientParameter.FsApiType:\x04HDFS\x12\x0b\n\x03uri\x18\x02 \x01(\t\x12\x0c\n\x04user\x18\x03 \x01(\t\x12\x0e\n\x06passwd\x18\x04 \x01(\t\x12\x13\n\x0b\x62uffer_size\x18\x05 \x01(\x05\x12\x12\n\nhadoop_bin\x18\x33 \x01(\t\x12\x10\n\x08\x61\x66s_conf\x18\x65 \x01(\t\"\x1e\n\tFsApiType\x12\x08\n\x04HDFS\x10\x00\x12\x07\n\x03\x41\x46S\x10\x01*4\n\tTableType\x12\x13\n\x0fPS_SPARSE_TABLE\x10\x00\x12\x12\n\x0ePS_DENSE_TABLE\x10\x01*\x9c\x03\n\x07PsCmdID\x12\x17\n\x13PS_PULL_DENSE_TABLE\x10\x00\x12\x17\n\x13PS_PUSH_DENSE_TABLE\x10\x01\x12\x18\n\x14PS_PULL_SPARSE_TABLE\x10\x02\x12\x18\n\x14PS_PUSH_SPARSE_TABLE\x10\x03\x12\x13\n\x0fPS_SHRINK_TABLE\x10\x04\x12\x15\n\x11PS_SAVE_ONE_TABLE\x10\x05\x12\x15\n\x11PS_SAVE_ALL_TABLE\x10\x06\x12\x15\n\x11PS_LOAD_ONE_TABLE\x10\x07\x12\x15\n\x11PS_LOAD_ALL_TABLE\x10\x08\x12\x16\n\x12PS_CLEAR_ONE_TABLE\x10\t\x12\x16\n\x12PS_CLEAR_ALL_TABLE\x10\n\x12\x17\n\x13PS_PUSH_DENSE_PARAM\x10\x0b\x12\x12\n\x0ePS_STOP_SERVER\x10\x0c\x12\x1b\n\x17PS_SAVE_ONE_CACHE_TABLE\x10\r\x12\x1a\n\x16PS_GET_CACHE_THRESHOLD\x10\x0e\x12\x14\n\x10PS_CACHE_SHUFFLE\x10\x0f\x12\x0e\n\nPS_S2S_MSG\x10\x65\x32K\n\tPsService\x12>\n\x07service\x12\x18.paddle.PsRequestMessage\x1a\x19.paddle.PsResponseMessageB\x03\x80\x01\x01' + '\n\x08ps.proto\x12\x06paddle\"\x9e\x02\n\x0bPSParameter\x12\x14\n\x0cworker_class\x18\x01 \x01(\t\x12\x14\n\x0cserver_class\x18\x02 \x01(\t\x12\x16\n\x0einstance_class\x18\x03 \x01(\t\x12-\n\x0cworker_param\x18\x65 \x01(\x0b\x32\x17.paddle.WorkerParameter\x12-\n\x0cserver_param\x18\x66 \x01(\x0b\x32\x17.paddle.ServerParameter\x12\x38\n\rtrainer_param\x18\xad\x02 \x03(\x0b\x32 .paddle.DownpourTrainerParameter\x12\x33\n\x0f\x66s_client_param\x18\xf5\x03 \x01(\x0b\x32\x19.paddle.FsClientParameter\"Q\n\x0fWorkerParameter\x12>\n\x15\x64ownpour_worker_param\x18\x01 \x01(\x0b\x32\x1f.paddle.DownpourWorkerParameter\"Q\n\x0fServerParameter\x12>\n\x15\x64ownpour_server_param\x18\x01 \x01(\x0b\x32\x1f.paddle.DownpourServerParameter\"O\n\x17\x44ownpourWorkerParameter\x12\x34\n\x14\x64ownpour_table_param\x18\x01 \x03(\x0b\x32\x16.paddle.TableParameter\"\xfd\x01\n\x18\x44ownpourTrainerParameter\x12\x30\n\x0b\x64\x65nse_table\x18\x01 \x03(\x0b\x32\x1b.paddle.DenseTableParameter\x12\x32\n\x0csparse_table\x18\x02 \x03(\x0b\x32\x1c.paddle.SparseTableParameter\x12\x1d\n\x15push_sparse_per_batch\x18\x03 \x01(\x05\x12\x1c\n\x14push_dense_per_batch\x18\x04 \x01(\x05\x12\x0f\n\x07skip_op\x18\x05 \x03(\t\x12-\n\x0eprogram_config\x18\x06 \x03(\x0b\x32\x15.paddle.ProgramConfig\"\x99\x01\n\rProgramConfig\x12\x12\n\nprogram_id\x18\x01 \x02(\t\x12\x1c\n\x14push_sparse_table_id\x18\x02 \x03(\x05\x12\x1b\n\x13push_dense_table_id\x18\x03 \x03(\x05\x12\x1c\n\x14pull_sparse_table_id\x18\x04 \x03(\x05\x12\x1b\n\x13pull_dense_table_id\x18\x05 \x03(\x05\"{\n\x13\x44\x65nseTableParameter\x12\x10\n\x08table_id\x18\x01 \x01(\x05\x12\x1b\n\x13\x64\x65nse_variable_name\x18\x02 \x03(\t\x12$\n\x1c\x64\x65nse_gradient_variable_name\x18\x03 \x03(\t\x12\x0f\n\x07\x66\x65\x61_dim\x18\x04 \x01(\x05\"z\n\x14SparseTableParameter\x12\x10\n\x08table_id\x18\x01 \x01(\x05\x12\x13\n\x0b\x66\x65\x61ture_dim\x18\x02 \x01(\x05\x12\x10\n\x08slot_key\x18\x03 \x03(\t\x12\x12\n\nslot_value\x18\x04 \x03(\t\x12\x15\n\rslot_gradient\x18\x05 \x03(\t\"\x86\x01\n\x17\x44ownpourServerParameter\x12\x34\n\x14\x64ownpour_table_param\x18\x01 \x03(\x0b\x32\x16.paddle.TableParameter\x12\x35\n\rservice_param\x18\x02 \x01(\x0b\x32\x1e.paddle.ServerServiceParameter\"\xd7\x01\n\x16ServerServiceParameter\x12*\n\x0cserver_class\x18\x01 \x01(\t:\x14\x44ownpourBrpcPsServer\x12*\n\x0c\x63lient_class\x18\x02 \x01(\t:\x14\x44ownpourBrpcPsClient\x12(\n\rservice_class\x18\x03 \x01(\t:\x11\x44ownpourPsService\x12\x1c\n\x11start_server_port\x18\x04 \x01(\r:\x01\x30\x12\x1d\n\x11server_thread_num\x18\x05 \x01(\r:\x02\x31\x32\"\xc0\x02\n\x0eTableParameter\x12\x10\n\x08table_id\x18\x01 \x01(\x04\x12\x13\n\x0btable_class\x18\x02 \x01(\t\x12\x17\n\tshard_num\x18\x03 \x01(\x04:\x04\x31\x30\x30\x30\x12\x30\n\x08\x61\x63\x63\x65ssor\x18\x04 \x01(\x0b\x32\x1e.paddle.TableAccessorParameter\x12\x1f\n\x04type\x18\x05 \x01(\x0e\x32\x11.paddle.TableType\x12\x1f\n\x10\x63ompress_in_save\x18\x06 \x01(\x08:\x05\x66\x61lse\x12\'\n\x19\x65nable_sparse_table_cache\x18\x07 \x01(\x08:\x04true\x12(\n\x17sparse_table_cache_rate\x18\x08 \x01(\x01:\x07\x30.00055\x12\'\n\x1bsparse_table_cache_file_num\x18\t \x01(\r:\x02\x31\x36\"\xc2\x03\n\x16TableAccessorParameter\x12\x16\n\x0e\x61\x63\x63\x65ssor_class\x18\x01 \x01(\t\x12\x38\n\x10sparse_sgd_param\x18\x02 \x01(\x0b\x32\x1e.paddle.SparseSGDRuleParameter\x12\x36\n\x0f\x64\x65nse_sgd_param\x18\x03 \x01(\x0b\x32\x1d.paddle.DenseSGDRuleParameter\x12\x13\n\x07\x66\x65\x61_dim\x18\x04 \x01(\r:\x02\x31\x31\x12\x15\n\nembedx_dim\x18\x05 \x01(\r:\x01\x38\x12\x1c\n\x10\x65mbedx_threshold\x18\x06 \x01(\r:\x02\x31\x30\x12G\n\x17\x64ownpour_accessor_param\x18\x07 \x01(\x0b\x32&.paddle.DownpourTableAccessorParameter\x12\x45\n\x19table_accessor_save_param\x18\x08 \x03(\x0b\x32\".paddle.TableAccessorSaveParameter\x12\x44\n\x16sparse_commonsgd_param\x18\t \x01(\x0b\x32$.paddle.SparseCommonSGDRuleParameter\"\x96\x02\n\x1e\x44ownpourTableAccessorParameter\x12\x19\n\x0cnonclk_coeff\x18\x01 \x01(\x02:\x03\x30.1\x12\x16\n\x0b\x63lick_coeff\x18\x02 \x01(\x02:\x01\x31\x12\x1b\n\x0e\x62\x61se_threshold\x18\x03 \x01(\x02:\x03\x31.5\x12\x1d\n\x0f\x64\x65lta_threshold\x18\x04 \x01(\x02:\x04\x30.25\x12\x1b\n\x0f\x64\x65lta_keep_days\x18\x05 \x01(\x02:\x02\x31\x36\x12#\n\x15show_click_decay_rate\x18\x06 \x01(\x02:\x04\x30.98\x12\x1d\n\x10\x64\x65lete_threshold\x18\x07 \x01(\x02:\x03\x30.8\x12$\n\x18\x64\x65lete_after_unseen_days\x18\x08 \x01(\x02:\x02\x33\x30\"S\n\x1aTableAccessorSaveParameter\x12\r\n\x05param\x18\x01 \x01(\r\x12\x11\n\tconverter\x18\x02 \x01(\t\x12\x13\n\x0b\x64\x65\x63onverter\x18\x03 \x01(\t\"e\n\x10PsRequestMessage\x12\x0e\n\x06\x63md_id\x18\x01 \x02(\r\x12\x10\n\x08table_id\x18\x02 \x01(\r\x12\x0e\n\x06params\x18\x03 \x03(\x0c\x12\x11\n\tclient_id\x18\x04 \x01(\x05\x12\x0c\n\x04\x64\x61ta\x18\x05 \x01(\x0c\"\x85\x01\n\x16SparseSGDRuleParameter\x12\x1b\n\rlearning_rate\x18\x01 \x01(\x01:\x04\x30.05\x12\x18\n\rinitial_g2sum\x18\x02 \x01(\x01:\x01\x33\x12\x1d\n\rinitial_range\x18\x03 \x01(\x01:\x06\x30.0001\x12\x15\n\rweight_bounds\x18\x04 \x03(\x02\"\xc6\x01\n\x1cSparseCommonSGDRuleParameter\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x32\n\x05naive\x18\x02 \x01(\x0b\x32#.paddle.SparseNaiveSGDRuleParameter\x12\x36\n\x07\x61\x64\x61grad\x18\x03 \x01(\x0b\x32%.paddle.SparseAdagradSGDRuleParameter\x12,\n\x04\x61\x64\x61m\x18\x04 \x01(\x0b\x32\x1e.paddle.SparseAdamSGDParameter\"p\n\x1bSparseNaiveSGDRuleParameter\x12\x1b\n\rlearning_rate\x18\x01 \x01(\x01:\x04\x30.05\x12\x1d\n\rinitial_range\x18\x02 \x01(\x01:\x06\x30.0001\x12\x15\n\rweight_bounds\x18\x03 \x03(\x02\"\x8c\x01\n\x1dSparseAdagradSGDRuleParameter\x12\x1b\n\rlearning_rate\x18\x01 \x01(\x01:\x04\x30.05\x12\x18\n\rinitial_g2sum\x18\x02 \x01(\x01:\x01\x33\x12\x1d\n\rinitial_range\x18\x03 \x01(\x01:\x06\x30.0001\x12\x15\n\rweight_bounds\x18\x04 \x03(\x02\"\xc8\x01\n\x16SparseAdamSGDParameter\x12\x1c\n\rlearning_rate\x18\x01 \x01(\x01:\x05\x30.001\x12\x1d\n\rinitial_range\x18\x02 \x01(\x01:\x06\x30.0001\x12\x1d\n\x10\x62\x65ta1_decay_rate\x18\x03 \x01(\x01:\x03\x30.9\x12\x1f\n\x10\x62\x65ta2_decay_rate\x18\x04 \x01(\x01:\x05\x30.999\x12\x1a\n\x0b\x61\x64\x61_epsilon\x18\x05 \x01(\x01:\x05\x31\x65-08\x12\x15\n\rweight_bounds\x18\x06 \x03(\x02\"\xe1\x01\n\x15\x44\x65nseSGDRuleParameter\x12\x0c\n\x04name\x18\x01 \x01(\t\x12&\n\x04\x61\x64\x61m\x18\x02 \x01(\x0b\x32\x18.paddle.AdamSGDParameter\x12(\n\x05naive\x18\x03 \x01(\x0b\x32\x19.paddle.NaiveSGDParameter\x12,\n\x07summary\x18\x04 \x01(\x0b\x32\x1b.paddle.SummarySGDParameter\x12:\n\x0emoving_average\x18\x05 \x01(\x0b\x32\".paddle.MovingAverageRuleParameter\"\xac\x01\n\x10\x41\x64\x61mSGDParameter\x12\x1c\n\rlearning_rate\x18\x01 \x01(\x01:\x05\x35\x65-06\x12 \n\x0e\x61vg_decay_rate\x18\x02 \x01(\x01:\x08\x30.999993\x12\x1e\n\x0e\x61\x64\x61_decay_rate\x18\x03 \x01(\x01:\x06\x30.9999\x12\x1a\n\x0b\x61\x64\x61_epsilon\x18\x04 \x01(\x01:\x05\x31\x65-08\x12\x1c\n\x0emom_decay_rate\x18\x05 \x01(\x01:\x04\x30.99\"J\n\x11NaiveSGDParameter\x12\x1d\n\rlearning_rate\x18\x01 \x01(\x01:\x06\x30.0002\x12\x16\n\x0e\x61vg_decay_rate\x18\x02 \x01(\x01\";\n\x13SummarySGDParameter\x12$\n\x12summary_decay_rate\x18\x01 \x01(\x01:\x08\x30.999999\".\n\x1aMovingAverageRuleParameter\x12\x10\n\x08momentum\x18\x01 \x01(\x01\"I\n\x11PsResponseMessage\x12\x13\n\x08\x65rr_code\x18\x01 \x02(\x05:\x01\x30\x12\x11\n\x07\x65rr_msg\x18\x02 \x02(\t:\x00\x12\x0c\n\x04\x64\x61ta\x18\x03 \x01(\x0c\"\xd5\x01\n\x11\x46sClientParameter\x12:\n\x07\x66s_type\x18\x01 \x01(\x0e\x32#.paddle.FsClientParameter.FsApiType:\x04HDFS\x12\x0b\n\x03uri\x18\x02 \x01(\t\x12\x0c\n\x04user\x18\x03 \x01(\t\x12\x0e\n\x06passwd\x18\x04 \x01(\t\x12\x13\n\x0b\x62uffer_size\x18\x05 \x01(\x05\x12\x12\n\nhadoop_bin\x18\x33 \x01(\t\x12\x10\n\x08\x61\x66s_conf\x18\x65 \x01(\t\"\x1e\n\tFsApiType\x12\x08\n\x04HDFS\x10\x00\x12\x07\n\x03\x41\x46S\x10\x01*4\n\tTableType\x12\x13\n\x0fPS_SPARSE_TABLE\x10\x00\x12\x12\n\x0ePS_DENSE_TABLE\x10\x01*\xa1\x04\n\x07PsCmdID\x12\x17\n\x13PS_PULL_DENSE_TABLE\x10\x00\x12\x17\n\x13PS_PUSH_DENSE_TABLE\x10\x01\x12\x18\n\x14PS_PULL_SPARSE_TABLE\x10\x02\x12\x18\n\x14PS_PUSH_SPARSE_TABLE\x10\x03\x12\x13\n\x0fPS_SHRINK_TABLE\x10\x04\x12\x15\n\x11PS_SAVE_ONE_TABLE\x10\x05\x12\x15\n\x11PS_SAVE_ALL_TABLE\x10\x06\x12\x15\n\x11PS_LOAD_ONE_TABLE\x10\x07\x12\x15\n\x11PS_LOAD_ALL_TABLE\x10\x08\x12\x16\n\x12PS_CLEAR_ONE_TABLE\x10\t\x12\x16\n\x12PS_CLEAR_ALL_TABLE\x10\n\x12\x17\n\x13PS_PUSH_DENSE_PARAM\x10\x0b\x12\x12\n\x0ePS_STOP_SERVER\x10\x0c\x12\x1b\n\x17PS_SAVE_ONE_CACHE_TABLE\x10\r\x12\x1a\n\x16PS_GET_CACHE_THRESHOLD\x10\x0e\x12\x14\n\x10PS_CACHE_SHUFFLE\x10\x0f\x12\x11\n\rPS_COPY_TABLE\x10\x10\x12\x1c\n\x18PS_COPY_TABLE_BY_FEASIGN\x10\x11\x12(\n$PS_PULL_SPARSE_TABLE_WITH_DEPENDENCY\x10\x12\x12(\n$PS_PUSH_SPARSE_TABLE_WITH_DEPENDENCY\x10\x13\x12\x0e\n\nPS_S2S_MSG\x10\x65\x32K\n\tPsService\x12>\n\x07service\x12\x18.paddle.PsRequestMessage\x1a\x19.paddle.PsResponseMessageB\x03\x80\x01\x01' )) _sym_db.RegisterFileDescriptor(DESCRIPTOR) @@ -49,8 +49,8 @@ _TABLETYPE = _descriptor.EnumDescriptor( ], containing_type=None, options=None, - serialized_start=3762, - serialized_end=3814, ) + serialized_start=4493, + serialized_end=4545, ) _sym_db.RegisterEnumDescriptor(_TABLETYPE) TableType = enum_type_wrapper.EnumTypeWrapper(_TABLETYPE) @@ -150,12 +150,32 @@ _PSCMDID = _descriptor.EnumDescriptor( options=None, type=None), _descriptor.EnumValueDescriptor( - name='PS_S2S_MSG', index=16, number=101, options=None, type=None), + name='PS_COPY_TABLE', index=16, number=16, options=None, type=None), + _descriptor.EnumValueDescriptor( + name='PS_COPY_TABLE_BY_FEASIGN', + index=17, + number=17, + options=None, + type=None), + _descriptor.EnumValueDescriptor( + name='PS_PULL_SPARSE_TABLE_WITH_DEPENDENCY', + index=18, + number=18, + options=None, + type=None), + _descriptor.EnumValueDescriptor( + name='PS_PUSH_SPARSE_TABLE_WITH_DEPENDENCY', + index=19, + number=19, + options=None, + type=None), + _descriptor.EnumValueDescriptor( + name='PS_S2S_MSG', index=20, number=101, options=None, type=None), ], containing_type=None, options=None, - serialized_start=3817, - serialized_end=4229, ) + serialized_start=4548, + serialized_end=5093, ) _sym_db.RegisterEnumDescriptor(_PSCMDID) PsCmdID = enum_type_wrapper.EnumTypeWrapper(_PSCMDID) @@ -177,6 +197,10 @@ PS_STOP_SERVER = 12 PS_SAVE_ONE_CACHE_TABLE = 13 PS_GET_CACHE_THRESHOLD = 14 PS_CACHE_SHUFFLE = 15 +PS_COPY_TABLE = 16 +PS_COPY_TABLE_BY_FEASIGN = 17 +PS_PULL_SPARSE_TABLE_WITH_DEPENDENCY = 18 +PS_PUSH_SPARSE_TABLE_WITH_DEPENDENCY = 19 PS_S2S_MSG = 101 _FSCLIENTPARAMETER_FSAPITYPE = _descriptor.EnumDescriptor( @@ -192,8 +216,8 @@ _FSCLIENTPARAMETER_FSAPITYPE = _descriptor.EnumDescriptor( ], containing_type=None, options=None, - serialized_start=3730, - serialized_end=3760, ) + serialized_start=4461, + serialized_end=4491, ) _sym_db.RegisterEnumDescriptor(_FSCLIENTPARAMETER_FSAPITYPE) _PSPARAMETER = _descriptor.Descriptor( @@ -1276,6 +1300,22 @@ _TABLEACCESSORPARAMETER = _descriptor.Descriptor( is_extension=False, extension_scope=None, options=None), + _descriptor.FieldDescriptor( + name='sparse_commonsgd_param', + full_name='paddle.TableAccessorParameter.sparse_commonsgd_param', + index=8, + number=9, + type=11, + cpp_type=10, + label=1, + has_default_value=False, + default_value=None, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None), ], extensions=[], nested_types=[], @@ -1286,7 +1326,7 @@ _TABLEACCESSORPARAMETER = _descriptor.Descriptor( extension_ranges=[], oneofs=[], serialized_start=1896, - serialized_end=2276, ) + serialized_end=2346, ) _DOWNPOURTABLEACCESSORPARAMETER = _descriptor.Descriptor( name='DownpourTableAccessorParameter', @@ -1432,8 +1472,8 @@ _DOWNPOURTABLEACCESSORPARAMETER = _descriptor.Descriptor( syntax='proto2', extension_ranges=[], oneofs=[], - serialized_start=2279, - serialized_end=2557, ) + serialized_start=2349, + serialized_end=2627, ) _TABLEACCESSORSAVEPARAMETER = _descriptor.Descriptor( name='TableAccessorSaveParameter', @@ -1499,8 +1539,8 @@ _TABLEACCESSORSAVEPARAMETER = _descriptor.Descriptor( syntax='proto2', extension_ranges=[], oneofs=[], - serialized_start=2559, - serialized_end=2642, ) + serialized_start=2629, + serialized_end=2712, ) _PSREQUESTMESSAGE = _descriptor.Descriptor( name='PsRequestMessage', @@ -1598,8 +1638,8 @@ _PSREQUESTMESSAGE = _descriptor.Descriptor( syntax='proto2', extension_ranges=[], oneofs=[], - serialized_start=2644, - serialized_end=2745, ) + serialized_start=2714, + serialized_end=2815, ) _SPARSESGDRULEPARAMETER = _descriptor.Descriptor( name='SparseSGDRuleParameter', @@ -1681,8 +1721,356 @@ _SPARSESGDRULEPARAMETER = _descriptor.Descriptor( syntax='proto2', extension_ranges=[], oneofs=[], - serialized_start=2748, - serialized_end=2881, ) + serialized_start=2818, + serialized_end=2951, ) + +_SPARSECOMMONSGDRULEPARAMETER = _descriptor.Descriptor( + name='SparseCommonSGDRuleParameter', + full_name='paddle.SparseCommonSGDRuleParameter', + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name='name', + full_name='paddle.SparseCommonSGDRuleParameter.name', + index=0, + number=1, + type=9, + cpp_type=9, + label=1, + has_default_value=False, + default_value=_b("").decode('utf-8'), + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None), + _descriptor.FieldDescriptor( + name='naive', + full_name='paddle.SparseCommonSGDRuleParameter.naive', + index=1, + number=2, + type=11, + cpp_type=10, + label=1, + has_default_value=False, + default_value=None, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None), + _descriptor.FieldDescriptor( + name='adagrad', + full_name='paddle.SparseCommonSGDRuleParameter.adagrad', + index=2, + number=3, + type=11, + cpp_type=10, + label=1, + has_default_value=False, + default_value=None, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None), + _descriptor.FieldDescriptor( + name='adam', + full_name='paddle.SparseCommonSGDRuleParameter.adam', + index=3, + number=4, + type=11, + cpp_type=10, + label=1, + has_default_value=False, + default_value=None, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None), + ], + extensions=[], + nested_types=[], + enum_types=[], + options=None, + is_extendable=False, + syntax='proto2', + extension_ranges=[], + oneofs=[], + serialized_start=2954, + serialized_end=3152, ) + +_SPARSENAIVESGDRULEPARAMETER = _descriptor.Descriptor( + name='SparseNaiveSGDRuleParameter', + full_name='paddle.SparseNaiveSGDRuleParameter', + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name='learning_rate', + full_name='paddle.SparseNaiveSGDRuleParameter.learning_rate', + index=0, + number=1, + type=1, + cpp_type=5, + label=1, + has_default_value=True, + default_value=float(0.05), + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None), + _descriptor.FieldDescriptor( + name='initial_range', + full_name='paddle.SparseNaiveSGDRuleParameter.initial_range', + index=1, + number=2, + type=1, + cpp_type=5, + label=1, + has_default_value=True, + default_value=float(0.0001), + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None), + _descriptor.FieldDescriptor( + name='weight_bounds', + full_name='paddle.SparseNaiveSGDRuleParameter.weight_bounds', + index=2, + number=3, + type=2, + cpp_type=6, + label=3, + has_default_value=False, + default_value=[], + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None), + ], + extensions=[], + nested_types=[], + enum_types=[], + options=None, + is_extendable=False, + syntax='proto2', + extension_ranges=[], + oneofs=[], + serialized_start=3154, + serialized_end=3266, ) + +_SPARSEADAGRADSGDRULEPARAMETER = _descriptor.Descriptor( + name='SparseAdagradSGDRuleParameter', + full_name='paddle.SparseAdagradSGDRuleParameter', + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name='learning_rate', + full_name='paddle.SparseAdagradSGDRuleParameter.learning_rate', + index=0, + number=1, + type=1, + cpp_type=5, + label=1, + has_default_value=True, + default_value=float(0.05), + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None), + _descriptor.FieldDescriptor( + name='initial_g2sum', + full_name='paddle.SparseAdagradSGDRuleParameter.initial_g2sum', + index=1, + number=2, + type=1, + cpp_type=5, + label=1, + has_default_value=True, + default_value=float(3), + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None), + _descriptor.FieldDescriptor( + name='initial_range', + full_name='paddle.SparseAdagradSGDRuleParameter.initial_range', + index=2, + number=3, + type=1, + cpp_type=5, + label=1, + has_default_value=True, + default_value=float(0.0001), + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None), + _descriptor.FieldDescriptor( + name='weight_bounds', + full_name='paddle.SparseAdagradSGDRuleParameter.weight_bounds', + index=3, + number=4, + type=2, + cpp_type=6, + label=3, + has_default_value=False, + default_value=[], + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None), + ], + extensions=[], + nested_types=[], + enum_types=[], + options=None, + is_extendable=False, + syntax='proto2', + extension_ranges=[], + oneofs=[], + serialized_start=3269, + serialized_end=3409, ) + +_SPARSEADAMSGDPARAMETER = _descriptor.Descriptor( + name='SparseAdamSGDParameter', + full_name='paddle.SparseAdamSGDParameter', + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name='learning_rate', + full_name='paddle.SparseAdamSGDParameter.learning_rate', + index=0, + number=1, + type=1, + cpp_type=5, + label=1, + has_default_value=True, + default_value=float(0.001), + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None), + _descriptor.FieldDescriptor( + name='initial_range', + full_name='paddle.SparseAdamSGDParameter.initial_range', + index=1, + number=2, + type=1, + cpp_type=5, + label=1, + has_default_value=True, + default_value=float(0.0001), + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None), + _descriptor.FieldDescriptor( + name='beta1_decay_rate', + full_name='paddle.SparseAdamSGDParameter.beta1_decay_rate', + index=2, + number=3, + type=1, + cpp_type=5, + label=1, + has_default_value=True, + default_value=float(0.9), + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None), + _descriptor.FieldDescriptor( + name='beta2_decay_rate', + full_name='paddle.SparseAdamSGDParameter.beta2_decay_rate', + index=3, + number=4, + type=1, + cpp_type=5, + label=1, + has_default_value=True, + default_value=float(0.999), + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None), + _descriptor.FieldDescriptor( + name='ada_epsilon', + full_name='paddle.SparseAdamSGDParameter.ada_epsilon', + index=4, + number=5, + type=1, + cpp_type=5, + label=1, + has_default_value=True, + default_value=float(1e-08), + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None), + _descriptor.FieldDescriptor( + name='weight_bounds', + full_name='paddle.SparseAdamSGDParameter.weight_bounds', + index=5, + number=6, + type=2, + cpp_type=6, + label=3, + has_default_value=False, + default_value=[], + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None), + ], + extensions=[], + nested_types=[], + enum_types=[], + options=None, + is_extendable=False, + syntax='proto2', + extension_ranges=[], + oneofs=[], + serialized_start=3412, + serialized_end=3612, ) _DENSESGDRULEPARAMETER = _descriptor.Descriptor( name='DenseSGDRuleParameter', @@ -1780,8 +2168,8 @@ _DENSESGDRULEPARAMETER = _descriptor.Descriptor( syntax='proto2', extension_ranges=[], oneofs=[], - serialized_start=2884, - serialized_end=3109, ) + serialized_start=3615, + serialized_end=3840, ) _ADAMSGDPARAMETER = _descriptor.Descriptor( name='AdamSGDParameter', @@ -1879,8 +2267,8 @@ _ADAMSGDPARAMETER = _descriptor.Descriptor( syntax='proto2', extension_ranges=[], oneofs=[], - serialized_start=3112, - serialized_end=3284, ) + serialized_start=3843, + serialized_end=4015, ) _NAIVESGDPARAMETER = _descriptor.Descriptor( name='NaiveSGDParameter', @@ -1930,8 +2318,8 @@ _NAIVESGDPARAMETER = _descriptor.Descriptor( syntax='proto2', extension_ranges=[], oneofs=[], - serialized_start=3286, - serialized_end=3360, ) + serialized_start=4017, + serialized_end=4091, ) _SUMMARYSGDPARAMETER = _descriptor.Descriptor( name='SummarySGDParameter', @@ -1965,8 +2353,8 @@ _SUMMARYSGDPARAMETER = _descriptor.Descriptor( syntax='proto2', extension_ranges=[], oneofs=[], - serialized_start=3362, - serialized_end=3421, ) + serialized_start=4093, + serialized_end=4152, ) _MOVINGAVERAGERULEPARAMETER = _descriptor.Descriptor( name='MovingAverageRuleParameter', @@ -2000,8 +2388,8 @@ _MOVINGAVERAGERULEPARAMETER = _descriptor.Descriptor( syntax='proto2', extension_ranges=[], oneofs=[], - serialized_start=3423, - serialized_end=3469, ) + serialized_start=4154, + serialized_end=4200, ) _PSRESPONSEMESSAGE = _descriptor.Descriptor( name='PsResponseMessage', @@ -2067,8 +2455,8 @@ _PSRESPONSEMESSAGE = _descriptor.Descriptor( syntax='proto2', extension_ranges=[], oneofs=[], - serialized_start=3471, - serialized_end=3544, ) + serialized_start=4202, + serialized_end=4275, ) _FSCLIENTPARAMETER = _descriptor.Descriptor( name='FsClientParameter', @@ -2198,8 +2586,8 @@ _FSCLIENTPARAMETER = _descriptor.Descriptor( syntax='proto2', extension_ranges=[], oneofs=[], - serialized_start=3547, - serialized_end=3760, ) + serialized_start=4278, + serialized_end=4491, ) _PSPARAMETER.fields_by_name['worker_param'].message_type = _WORKERPARAMETER _PSPARAMETER.fields_by_name['server_param'].message_type = _SERVERPARAMETER @@ -2233,6 +2621,14 @@ _TABLEACCESSORPARAMETER.fields_by_name[ 'downpour_accessor_param'].message_type = _DOWNPOURTABLEACCESSORPARAMETER _TABLEACCESSORPARAMETER.fields_by_name[ 'table_accessor_save_param'].message_type = _TABLEACCESSORSAVEPARAMETER +_TABLEACCESSORPARAMETER.fields_by_name[ + 'sparse_commonsgd_param'].message_type = _SPARSECOMMONSGDRULEPARAMETER +_SPARSECOMMONSGDRULEPARAMETER.fields_by_name[ + 'naive'].message_type = _SPARSENAIVESGDRULEPARAMETER +_SPARSECOMMONSGDRULEPARAMETER.fields_by_name[ + 'adagrad'].message_type = _SPARSEADAGRADSGDRULEPARAMETER +_SPARSECOMMONSGDRULEPARAMETER.fields_by_name[ + 'adam'].message_type = _SPARSEADAMSGDPARAMETER _DENSESGDRULEPARAMETER.fields_by_name['adam'].message_type = _ADAMSGDPARAMETER _DENSESGDRULEPARAMETER.fields_by_name['naive'].message_type = _NAIVESGDPARAMETER _DENSESGDRULEPARAMETER.fields_by_name[ @@ -2266,6 +2662,14 @@ DESCRIPTOR.message_types_by_name[ DESCRIPTOR.message_types_by_name['PsRequestMessage'] = _PSREQUESTMESSAGE DESCRIPTOR.message_types_by_name[ 'SparseSGDRuleParameter'] = _SPARSESGDRULEPARAMETER +DESCRIPTOR.message_types_by_name[ + 'SparseCommonSGDRuleParameter'] = _SPARSECOMMONSGDRULEPARAMETER +DESCRIPTOR.message_types_by_name[ + 'SparseNaiveSGDRuleParameter'] = _SPARSENAIVESGDRULEPARAMETER +DESCRIPTOR.message_types_by_name[ + 'SparseAdagradSGDRuleParameter'] = _SPARSEADAGRADSGDRULEPARAMETER +DESCRIPTOR.message_types_by_name[ + 'SparseAdamSGDParameter'] = _SPARSEADAMSGDPARAMETER DESCRIPTOR.message_types_by_name[ 'DenseSGDRuleParameter'] = _DENSESGDRULEPARAMETER DESCRIPTOR.message_types_by_name['AdamSGDParameter'] = _ADAMSGDPARAMETER @@ -2438,6 +2842,46 @@ SparseSGDRuleParameter = _reflection.GeneratedProtocolMessageType( )) _sym_db.RegisterMessage(SparseSGDRuleParameter) +SparseCommonSGDRuleParameter = _reflection.GeneratedProtocolMessageType( + 'SparseCommonSGDRuleParameter', + (_message.Message, ), + dict( + DESCRIPTOR=_SPARSECOMMONSGDRULEPARAMETER, + __module__='ps_pb2' + # @@protoc_insertion_point(class_scope:paddle.SparseCommonSGDRuleParameter) + )) +_sym_db.RegisterMessage(SparseCommonSGDRuleParameter) + +SparseNaiveSGDRuleParameter = _reflection.GeneratedProtocolMessageType( + 'SparseNaiveSGDRuleParameter', + (_message.Message, ), + dict( + DESCRIPTOR=_SPARSENAIVESGDRULEPARAMETER, + __module__='ps_pb2' + # @@protoc_insertion_point(class_scope:paddle.SparseNaiveSGDRuleParameter) + )) +_sym_db.RegisterMessage(SparseNaiveSGDRuleParameter) + +SparseAdagradSGDRuleParameter = _reflection.GeneratedProtocolMessageType( + 'SparseAdagradSGDRuleParameter', + (_message.Message, ), + dict( + DESCRIPTOR=_SPARSEADAGRADSGDRULEPARAMETER, + __module__='ps_pb2' + # @@protoc_insertion_point(class_scope:paddle.SparseAdagradSGDRuleParameter) + )) +_sym_db.RegisterMessage(SparseAdagradSGDRuleParameter) + +SparseAdamSGDParameter = _reflection.GeneratedProtocolMessageType( + 'SparseAdamSGDParameter', + (_message.Message, ), + dict( + DESCRIPTOR=_SPARSEADAMSGDPARAMETER, + __module__='ps_pb2' + # @@protoc_insertion_point(class_scope:paddle.SparseAdamSGDParameter) + )) +_sym_db.RegisterMessage(SparseAdamSGDParameter) + DenseSGDRuleParameter = _reflection.GeneratedProtocolMessageType( 'DenseSGDRuleParameter', (_message.Message, ), diff --git a/python/paddle/fluid/trainer_desc.py b/python/paddle/fluid/trainer_desc.py index 75a1472d09..18753daf4e 100644 --- a/python/paddle/fluid/trainer_desc.py +++ b/python/paddle/fluid/trainer_desc.py @@ -76,6 +76,9 @@ class TrainerDesc(object): def _set_use_cvm(self, use_cvm=False): self.proto_desc.use_cvm = use_cvm + def _set_no_cvm(self, no_cvm=False): + self.proto_desc.no_cvm = no_cvm + def _set_scale_datanorm(self, scale_datanorm=-1): self.proto_desc.scale_datanorm = scale_datanorm diff --git a/python/paddle/fluid/trainer_factory.py b/python/paddle/fluid/trainer_factory.py index 70154e383a..b21d3164fa 100644 --- a/python/paddle/fluid/trainer_factory.py +++ b/python/paddle/fluid/trainer_factory.py @@ -52,6 +52,8 @@ class TrainerFactory(object): trainer._set_fleet_desc(opt_info["fleet_desc"]) if opt_info.get("use_cvm") is not None: trainer._set_use_cvm(opt_info["use_cvm"]) + if opt_info.get("no_cvm") is not None: + trainer._set_no_cvm(opt_info["no_cvm"]) if opt_info.get("scale_datanorm") is not None: trainer._set_scale_datanorm(opt_info["scale_datanorm"]) if opt_info.get("dump_slot") is not None: -- GitLab