未验证 提交 d8396281 编写于 作者: T Thunderbrook 提交者: GitHub

add slot to sparse table (#18686)

The change includes 2 things:

1. save delta model and shrink table are control by the same parameter before, now add delete_after_unseen_days to control shrink table.
2. value in sparse table has no slot before, now add slot in sparse table, and add DownpureCtrAccessor to support the new meta.
test=develop
上级 f0cfc3c3
...@@ -179,6 +179,7 @@ class DownpourWorker : public HogwildWorker { ...@@ -179,6 +179,7 @@ class DownpourWorker : public HogwildWorker {
private: private:
bool need_to_push_dense_; bool need_to_push_dense_;
bool dump_slot_;
bool need_to_push_sparse_; bool need_to_push_sparse_;
DownpourWorkerParameter param_; DownpourWorkerParameter param_;
// just save the value in param_ for easy access // just save the value in param_ for easy access
...@@ -285,7 +286,6 @@ class SectionWorker : public DeviceWorker { ...@@ -285,7 +286,6 @@ class SectionWorker : public DeviceWorker {
int section_num_; int section_num_;
int pipeline_num_; int pipeline_num_;
int thread_id_; int thread_id_;
// This worker will consume scope from in_scope_queue_ // This worker will consume scope from in_scope_queue_
// and produce scope to out_scope_queue_ // and produce scope to out_scope_queue_
ScopeQueue* in_scope_queue_ = nullptr; ScopeQueue* in_scope_queue_ = nullptr;
......
...@@ -64,6 +64,7 @@ void DownpourWorker::Initialize(const TrainerDesc& desc) { ...@@ -64,6 +64,7 @@ void DownpourWorker::Initialize(const TrainerDesc& desc) {
fleet_ptr_ = FleetWrapper::GetInstance(); fleet_ptr_ = FleetWrapper::GetInstance();
fetch_config_ = desc.fetch_config(); fetch_config_ = desc.fetch_config();
use_cvm_ = desc.use_cvm(); use_cvm_ = desc.use_cvm();
dump_slot_ = desc.dump_slot();
} }
void DownpourWorker::CollectLabelInfo(size_t table_idx) { void DownpourWorker::CollectLabelInfo(size_t table_idx) {
...@@ -282,7 +283,8 @@ void DownpourWorker::TrainFilesWithProfiler() { ...@@ -282,7 +283,8 @@ void DownpourWorker::TrainFilesWithProfiler() {
fleet_ptr_->PushSparseVarsWithLabelAsync( fleet_ptr_->PushSparseVarsWithLabelAsync(
*thread_scope_, tid, features_[tid], feature_labels_[tid], *thread_scope_, tid, features_[tid], feature_labels_[tid],
sparse_key_names_[tid], sparse_grad_names_[tid], table.emb_dim(), sparse_key_names_[tid], sparse_grad_names_[tid], table.emb_dim(),
&feature_grads_[tid], &push_sparse_status_, cur_batch, use_cvm_); &feature_grads_[tid], &push_sparse_status_, cur_batch, use_cvm_,
dump_slot_);
timeline.Pause(); timeline.Pause();
push_sparse_time += timeline.ElapsedSec(); push_sparse_time += timeline.ElapsedSec();
total_time += timeline.ElapsedSec(); total_time += timeline.ElapsedSec();
...@@ -454,7 +456,8 @@ void DownpourWorker::TrainFiles() { ...@@ -454,7 +456,8 @@ void DownpourWorker::TrainFiles() {
fleet_ptr_->PushSparseVarsWithLabelAsync( fleet_ptr_->PushSparseVarsWithLabelAsync(
*thread_scope_, tid, features_[tid], feature_labels_[tid], *thread_scope_, tid, features_[tid], feature_labels_[tid],
sparse_key_names_[tid], sparse_grad_names_[tid], table.emb_dim(), sparse_key_names_[tid], sparse_grad_names_[tid], table.emb_dim(),
&feature_grads_[tid], &push_sparse_status_, cur_batch, use_cvm_); &feature_grads_[tid], &push_sparse_status_, cur_batch, use_cvm_,
dump_slot_);
} }
} }
......
...@@ -288,19 +288,27 @@ void FleetWrapper::PushSparseVarsWithLabelAsync( ...@@ -288,19 +288,27 @@ void FleetWrapper::PushSparseVarsWithLabelAsync(
const std::vector<std::string>& sparse_grad_names, const int emb_dim, const std::vector<std::string>& sparse_grad_names, const int emb_dim,
std::vector<std::vector<float>>* push_values, std::vector<std::vector<float>>* push_values,
std::vector<::std::future<int32_t>>* push_sparse_status, std::vector<::std::future<int32_t>>* push_sparse_status,
const int batch_size, const bool use_cvm) { const int batch_size, const bool use_cvm, const bool dump_slot) {
#ifdef PADDLE_WITH_PSLIB #ifdef PADDLE_WITH_PSLIB
int offset = 2; int offset = 2;
int slot_offset = 0;
int grad_dim = emb_dim; int grad_dim = emb_dim;
int show_index = 0;
int click_index = 1;
if (use_cvm) { if (use_cvm) {
offset = 0; offset = 0;
grad_dim = emb_dim - 2; grad_dim = emb_dim - 2;
} }
if (dump_slot) {
slot_offset = 1;
show_index = 1;
click_index = 2;
}
CHECK_GE(grad_dim, 0); CHECK_GE(grad_dim, 0);
push_values->resize(fea_keys.size() + 1); push_values->resize(fea_keys.size() + 1);
for (auto& t : *push_values) { for (auto& t : *push_values) {
t.resize(emb_dim + offset); t.resize(emb_dim + offset + slot_offset);
} }
uint64_t fea_idx = 0u; uint64_t fea_idx = 0u;
for (size_t i = 0; i < sparse_key_names.size(); ++i) { for (size_t i = 0; i < sparse_key_names.size(); ++i) {
...@@ -315,7 +323,10 @@ void FleetWrapper::PushSparseVarsWithLabelAsync( ...@@ -315,7 +323,10 @@ void FleetWrapper::PushSparseVarsWithLabelAsync(
} }
int len = tensor->numel(); int len = tensor->numel();
int64_t* ids = tensor->data<int64_t>(); int64_t* ids = tensor->data<int64_t>();
int slot = 0;
if (dump_slot) {
slot = boost::lexical_cast<int>(sparse_key_names[i]);
}
Variable* g_var = scope.FindVar(sparse_grad_names[i]); Variable* g_var = scope.FindVar(sparse_grad_names[i]);
CHECK(g_var != nullptr) << "var[" << sparse_grad_names[i] << "] not found"; CHECK(g_var != nullptr) << "var[" << sparse_grad_names[i] << "] not found";
LoDTensor* g_tensor = g_var->GetMutable<LoDTensor>(); LoDTensor* g_tensor = g_var->GetMutable<LoDTensor>();
...@@ -339,14 +350,19 @@ void FleetWrapper::PushSparseVarsWithLabelAsync( ...@@ -339,14 +350,19 @@ void FleetWrapper::PushSparseVarsWithLabelAsync(
} }
CHECK(fea_idx < (*push_values).size()); CHECK(fea_idx < (*push_values).size());
CHECK(fea_idx < fea_labels.size()); CHECK(fea_idx < fea_labels.size());
if (use_cvm) { if (use_cvm) {
memcpy((*push_values)[fea_idx].data() + offset, g, memcpy((*push_values)[fea_idx].data() + offset + slot_offset, g,
sizeof(float) * emb_dim); sizeof(float) * emb_dim);
} else { } else {
memcpy((*push_values)[fea_idx].data() + offset, g, memcpy((*push_values)[fea_idx].data() + offset + slot_offset, g,
sizeof(float) * emb_dim); sizeof(float) * emb_dim);
(*push_values)[fea_idx][0] = 1.0f; (*push_values)[fea_idx][show_index] = 1.0f;
(*push_values)[fea_idx][1] = static_cast<float>(fea_labels[fea_idx]); (*push_values)[fea_idx][click_index] =
static_cast<float>(fea_labels[fea_idx]);
}
if (dump_slot) {
(*push_values)[fea_idx][0] = static_cast<float>(slot);
} }
g += emb_dim; g += emb_dim;
fea_idx++; fea_idx++;
......
...@@ -100,7 +100,7 @@ class FleetWrapper { ...@@ -100,7 +100,7 @@ class FleetWrapper {
const std::vector<std::string>& sparse_grad_names, const int emb_dim, const std::vector<std::string>& sparse_grad_names, const int emb_dim,
std::vector<std::vector<float>>* push_values, std::vector<std::vector<float>>* push_values,
std::vector<::std::future<int32_t>>* push_sparse_status, std::vector<::std::future<int32_t>>* push_sparse_status,
const int batch_size, const bool use_cvm); const int batch_size, const bool use_cvm, const bool dump_slot);
// Push sparse variables to server in Async mode // Push sparse variables to server in Async mode
// Param<In>: scope, table_id, fea_keys, sparse_grad_names // Param<In>: scope, table_id, fea_keys, sparse_grad_names
......
...@@ -33,6 +33,7 @@ message TrainerDesc { ...@@ -33,6 +33,7 @@ message TrainerDesc {
optional bool debug = 6 [ default = false ]; optional bool debug = 6 [ default = false ];
optional FetchConfig fetch_config = 7; optional FetchConfig fetch_config = 7;
optional bool use_cvm = 8 [ default = false ]; optional bool use_cvm = 8 [ default = false ];
optional bool dump_slot = 9 [ default = false ];
// device worker parameters // device worker parameters
optional HogwildWorkerParameter hogwild_param = 101; optional HogwildWorkerParameter hogwild_param = 101;
......
...@@ -75,7 +75,7 @@ class DownpourServer(Server): ...@@ -75,7 +75,7 @@ class DownpourServer(Server):
table.type = pslib.PS_SPARSE_TABLE table.type = pslib.PS_SPARSE_TABLE
table.compress_in_save = True table.compress_in_save = True
table.shard_num = 1000 table.shard_num = 1000
table.accessor.accessor_class = "DownpourFeatureValueAccessor" table.accessor.accessor_class = "DownpourCtrAccessor"
table.accessor.sparse_sgd_param.learning_rate = learning_rate table.accessor.sparse_sgd_param.learning_rate = learning_rate
table.accessor.sparse_sgd_param.initial_g2sum = 3 table.accessor.sparse_sgd_param.initial_g2sum = 3
table.accessor.sparse_sgd_param.initial_range = 1e-4 table.accessor.sparse_sgd_param.initial_range = 1e-4
...@@ -88,7 +88,8 @@ class DownpourServer(Server): ...@@ -88,7 +88,8 @@ class DownpourServer(Server):
table.accessor.downpour_accessor_param.click_coeff = 2 table.accessor.downpour_accessor_param.click_coeff = 2
table.accessor.downpour_accessor_param.base_threshold = 0.2 table.accessor.downpour_accessor_param.base_threshold = 0.2
table.accessor.downpour_accessor_param.delta_threshold = 0.15 table.accessor.downpour_accessor_param.delta_threshold = 0.15
table.accessor.downpour_accessor_param.delta_keep_days = 31 table.accessor.downpour_accessor_param.delta_keep_days = 16
table.accessor.downpour_accessor_param.delete_after_unseen_days = 30
table.accessor.downpour_accessor_param.show_click_decay_rate = 0.999 table.accessor.downpour_accessor_param.show_click_decay_rate = 0.999
table.accessor.downpour_accessor_param.delete_threshold = 0.8 table.accessor.downpour_accessor_param.delete_threshold = 0.8
......
...@@ -162,6 +162,10 @@ class DistributedAdam(DistributedOptimizerImplBase): ...@@ -162,6 +162,10 @@ class DistributedAdam(DistributedOptimizerImplBase):
opt_info["fleet_desc"] = ps_param opt_info["fleet_desc"] = ps_param
opt_info["worker_skipped_ops"] = worker_skipped_ops opt_info["worker_skipped_ops"] = worker_skipped_ops
opt_info["use_cvm"] = strategy.get("use_cvm", False) opt_info["use_cvm"] = strategy.get("use_cvm", False)
opt_info["dump_slot"] = False
if server._server.downpour_server_param.downpour_table_param[
0].accessor.accessor_class == "DownpourCtrAccessor":
opt_info["dump_slot"] = True
for loss in losses: for loss in losses:
loss.block.program._fleet_opt = opt_info loss.block.program._fleet_opt = opt_info
......
...@@ -32,7 +32,7 @@ DESCRIPTOR = _descriptor.FileDescriptor( ...@@ -32,7 +32,7 @@ DESCRIPTOR = _descriptor.FileDescriptor(
package='paddle', package='paddle',
syntax='proto2', syntax='proto2',
serialized_pb=_b( serialized_pb=_b(
'\n\x08ps.proto\x12\x06paddle\"\x9e\x02\n\x0bPSParameter\x12\x14\n\x0cworker_class\x18\x01 \x01(\t\x12\x14\n\x0cserver_class\x18\x02 \x01(\t\x12\x16\n\x0einstance_class\x18\x03 \x01(\t\x12-\n\x0cworker_param\x18\x65 \x01(\x0b\x32\x17.paddle.WorkerParameter\x12-\n\x0cserver_param\x18\x66 \x01(\x0b\x32\x17.paddle.ServerParameter\x12\x38\n\rtrainer_param\x18\xad\x02 \x01(\x0b\x32 .paddle.DownpourTrainerParameter\x12\x33\n\x0f\x66s_client_param\x18\xf5\x03 \x01(\x0b\x32\x19.paddle.FsClientParameter\"Q\n\x0fWorkerParameter\x12>\n\x15\x64ownpour_worker_param\x18\x01 \x01(\x0b\x32\x1f.paddle.DownpourWorkerParameter\"Q\n\x0fServerParameter\x12>\n\x15\x64ownpour_server_param\x18\x01 \x01(\x0b\x32\x1f.paddle.DownpourServerParameter\"O\n\x17\x44ownpourWorkerParameter\x12\x34\n\x14\x64ownpour_table_param\x18\x01 \x03(\x0b\x32\x16.paddle.TableParameter\"\xfd\x01\n\x18\x44ownpourTrainerParameter\x12\x30\n\x0b\x64\x65nse_table\x18\x01 \x03(\x0b\x32\x1b.paddle.DenseTableParameter\x12\x32\n\x0csparse_table\x18\x02 \x03(\x0b\x32\x1c.paddle.SparseTableParameter\x12\x1d\n\x15push_sparse_per_batch\x18\x03 \x01(\x05\x12\x1c\n\x14push_dense_per_batch\x18\x04 \x01(\x05\x12\x0f\n\x07skip_op\x18\x05 \x03(\t\x12-\n\x0eprogram_config\x18\x06 \x03(\x0b\x32\x15.paddle.ProgramConfig\"\x99\x01\n\rProgramConfig\x12\x12\n\nprogram_id\x18\x01 \x02(\t\x12\x1c\n\x14push_sparse_table_id\x18\x02 \x03(\x05\x12\x1b\n\x13push_dense_table_id\x18\x03 \x03(\x05\x12\x1c\n\x14pull_sparse_table_id\x18\x04 \x03(\x05\x12\x1b\n\x13pull_dense_table_id\x18\x05 \x03(\x05\"{\n\x13\x44\x65nseTableParameter\x12\x10\n\x08table_id\x18\x01 \x01(\x05\x12\x1b\n\x13\x64\x65nse_variable_name\x18\x02 \x03(\t\x12$\n\x1c\x64\x65nse_gradient_variable_name\x18\x03 \x03(\t\x12\x0f\n\x07\x66\x65\x61_dim\x18\x04 \x01(\x05\"z\n\x14SparseTableParameter\x12\x10\n\x08table_id\x18\x01 \x01(\x05\x12\x13\n\x0b\x66\x65\x61ture_dim\x18\x02 \x01(\x05\x12\x10\n\x08slot_key\x18\x03 \x03(\t\x12\x12\n\nslot_value\x18\x04 \x03(\t\x12\x15\n\rslot_gradient\x18\x05 \x03(\t\"\x86\x01\n\x17\x44ownpourServerParameter\x12\x34\n\x14\x64ownpour_table_param\x18\x01 \x03(\x0b\x32\x16.paddle.TableParameter\x12\x35\n\rservice_param\x18\x02 \x01(\x0b\x32\x1e.paddle.ServerServiceParameter\"\xd7\x01\n\x16ServerServiceParameter\x12*\n\x0cserver_class\x18\x01 \x01(\t:\x14\x44ownpourBrpcPsServer\x12*\n\x0c\x63lient_class\x18\x02 \x01(\t:\x14\x44ownpourBrpcPsClient\x12(\n\rservice_class\x18\x03 \x01(\t:\x11\x44ownpourPsService\x12\x1c\n\x11start_server_port\x18\x04 \x01(\r:\x01\x30\x12\x1d\n\x11server_thread_num\x18\x05 \x01(\r:\x02\x31\x32\"\xc4\x01\n\x0eTableParameter\x12\x10\n\x08table_id\x18\x01 \x01(\x04\x12\x13\n\x0btable_class\x18\x02 \x01(\t\x12\x17\n\tshard_num\x18\x03 \x01(\x04:\x04\x31\x30\x30\x30\x12\x30\n\x08\x61\x63\x63\x65ssor\x18\x04 \x01(\x0b\x32\x1e.paddle.TableAccessorParameter\x12\x1f\n\x04type\x18\x05 \x01(\x0e\x32\x11.paddle.TableType\x12\x1f\n\x10\x63ompress_in_save\x18\x06 \x01(\x08:\x05\x66\x61lse\"\xf1\x02\n\x16TableAccessorParameter\x12\x16\n\x0e\x61\x63\x63\x65ssor_class\x18\x01 \x01(\t\x12\x38\n\x10sparse_sgd_param\x18\x02 \x01(\x0b\x32\x1e.paddle.SparseSGDRuleParameter\x12\x36\n\x0f\x64\x65nse_sgd_param\x18\x03 \x01(\x0b\x32\x1d.paddle.DenseSGDRuleParameter\x12\x0f\n\x07\x66\x65\x61_dim\x18\x04 \x01(\r\x12\x12\n\nembedx_dim\x18\x05 \x01(\r\x12\x18\n\x10\x65mbedx_threshold\x18\x06 \x01(\r\x12G\n\x17\x64ownpour_accessor_param\x18\x07 \x01(\x0b\x32&.paddle.DownpourTableAccessorParameter\x12\x45\n\x19table_accessor_save_param\x18\x08 \x03(\x0b\x32\".paddle.TableAccessorSaveParameter\"\xce\x01\n\x1e\x44ownpourTableAccessorParameter\x12\x14\n\x0cnonclk_coeff\x18\x01 \x01(\x02\x12\x13\n\x0b\x63lick_coeff\x18\x02 \x01(\x02\x12\x16\n\x0e\x62\x61se_threshold\x18\x03 \x01(\x02\x12\x17\n\x0f\x64\x65lta_threshold\x18\x04 \x01(\x02\x12\x17\n\x0f\x64\x65lta_keep_days\x18\x05 \x01(\x02\x12\x1d\n\x15show_click_decay_rate\x18\x06 \x01(\x02\x12\x18\n\x10\x64\x65lete_threshold\x18\x07 \x01(\x02\"S\n\x1aTableAccessorSaveParameter\x12\r\n\x05param\x18\x01 \x01(\r\x12\x11\n\tconverter\x18\x02 \x01(\t\x12\x13\n\x0b\x64\x65\x63onverter\x18\x03 \x01(\t\"e\n\x10PsRequestMessage\x12\x0e\n\x06\x63md_id\x18\x01 \x02(\r\x12\x10\n\x08table_id\x18\x02 \x01(\r\x12\x0e\n\x06params\x18\x03 \x03(\x0c\x12\x11\n\tclient_id\x18\x04 \x01(\x05\x12\x0c\n\x04\x64\x61ta\x18\x05 \x01(\x0c\"w\n\x16SparseSGDRuleParameter\x12\x15\n\rlearning_rate\x18\x01 \x01(\x01\x12\x15\n\rinitial_g2sum\x18\x02 \x01(\x01\x12\x18\n\rinitial_range\x18\x03 \x01(\x01:\x01\x30\x12\x15\n\rweight_bounds\x18\x04 \x03(\x02\"\xe1\x01\n\x15\x44\x65nseSGDRuleParameter\x12\x0c\n\x04name\x18\x01 \x01(\t\x12&\n\x04\x61\x64\x61m\x18\x02 \x01(\x0b\x32\x18.paddle.AdamSGDParameter\x12(\n\x05naive\x18\x03 \x01(\x0b\x32\x19.paddle.NaiveSGDParameter\x12,\n\x07summary\x18\x04 \x01(\x0b\x32\x1b.paddle.SummarySGDParameter\x12:\n\x0emoving_average\x18\x05 \x01(\x0b\x32\".paddle.MovingAverageRuleParameter\"\x86\x01\n\x10\x41\x64\x61mSGDParameter\x12\x15\n\rlearning_rate\x18\x01 \x01(\x01\x12\x16\n\x0e\x61vg_decay_rate\x18\x02 \x01(\x01\x12\x16\n\x0e\x61\x64\x61_decay_rate\x18\x03 \x01(\x01\x12\x13\n\x0b\x61\x64\x61_epsilon\x18\x04 \x01(\x01\x12\x16\n\x0emom_decay_rate\x18\x05 \x01(\x01\"B\n\x11NaiveSGDParameter\x12\x15\n\rlearning_rate\x18\x01 \x01(\x01\x12\x16\n\x0e\x61vg_decay_rate\x18\x02 \x01(\x01\";\n\x13SummarySGDParameter\x12$\n\x12summary_decay_rate\x18\x01 \x01(\x01:\x08\x30.999999\".\n\x1aMovingAverageRuleParameter\x12\x10\n\x08momentum\x18\x01 \x01(\x01\"I\n\x11PsResponseMessage\x12\x13\n\x08\x65rr_code\x18\x01 \x02(\x05:\x01\x30\x12\x11\n\x07\x65rr_msg\x18\x02 \x02(\t:\x00\x12\x0c\n\x04\x64\x61ta\x18\x03 \x01(\x0c\"\xd5\x01\n\x11\x46sClientParameter\x12:\n\x07\x66s_type\x18\x01 \x01(\x0e\x32#.paddle.FsClientParameter.FsApiType:\x04HDFS\x12\x0b\n\x03uri\x18\x02 \x01(\t\x12\x0c\n\x04user\x18\x03 \x01(\t\x12\x0e\n\x06passwd\x18\x04 \x01(\t\x12\x13\n\x0b\x62uffer_size\x18\x05 \x01(\x05\x12\x12\n\nhadoop_bin\x18\x33 \x01(\t\x12\x10\n\x08\x61\x66s_conf\x18\x65 \x01(\t\"\x1e\n\tFsApiType\x12\x08\n\x04HDFS\x10\x00\x12\x07\n\x03\x41\x46S\x10\x01*4\n\tTableType\x12\x13\n\x0fPS_SPARSE_TABLE\x10\x00\x12\x12\n\x0ePS_DENSE_TABLE\x10\x01*\xbd\x02\n\x07PsCmdID\x12\x17\n\x13PS_PULL_DENSE_TABLE\x10\x00\x12\x17\n\x13PS_PUSH_DENSE_TABLE\x10\x01\x12\x18\n\x14PS_PULL_SPARSE_TABLE\x10\x02\x12\x18\n\x14PS_PUSH_SPARSE_TABLE\x10\x03\x12\x13\n\x0fPS_SHRINK_TABLE\x10\x04\x12\x15\n\x11PS_SAVE_ONE_TABLE\x10\x05\x12\x15\n\x11PS_SAVE_ALL_TABLE\x10\x06\x12\x15\n\x11PS_LOAD_ONE_TABLE\x10\x07\x12\x15\n\x11PS_LOAD_ALL_TABLE\x10\x08\x12\x16\n\x12PS_CLEAR_ONE_TABLE\x10\t\x12\x16\n\x12PS_CLEAR_ALL_TABLE\x10\n\x12\x17\n\x13PS_PUSH_DENSE_PARAM\x10\x0b\x12\x12\n\x0ePS_STOP_SERVER\x10\x0c\x32K\n\tPsService\x12>\n\x07service\x12\x18.paddle.PsRequestMessage\x1a\x19.paddle.PsResponseMessageB\x03\x80\x01\x01' '\n\x08ps.proto\x12\x06paddle\"\x9e\x02\n\x0bPSParameter\x12\x14\n\x0cworker_class\x18\x01 \x01(\t\x12\x14\n\x0cserver_class\x18\x02 \x01(\t\x12\x16\n\x0einstance_class\x18\x03 \x01(\t\x12-\n\x0cworker_param\x18\x65 \x01(\x0b\x32\x17.paddle.WorkerParameter\x12-\n\x0cserver_param\x18\x66 \x01(\x0b\x32\x17.paddle.ServerParameter\x12\x38\n\rtrainer_param\x18\xad\x02 \x01(\x0b\x32 .paddle.DownpourTrainerParameter\x12\x33\n\x0f\x66s_client_param\x18\xf5\x03 \x01(\x0b\x32\x19.paddle.FsClientParameter\"Q\n\x0fWorkerParameter\x12>\n\x15\x64ownpour_worker_param\x18\x01 \x01(\x0b\x32\x1f.paddle.DownpourWorkerParameter\"Q\n\x0fServerParameter\x12>\n\x15\x64ownpour_server_param\x18\x01 \x01(\x0b\x32\x1f.paddle.DownpourServerParameter\"O\n\x17\x44ownpourWorkerParameter\x12\x34\n\x14\x64ownpour_table_param\x18\x01 \x03(\x0b\x32\x16.paddle.TableParameter\"\xfd\x01\n\x18\x44ownpourTrainerParameter\x12\x30\n\x0b\x64\x65nse_table\x18\x01 \x03(\x0b\x32\x1b.paddle.DenseTableParameter\x12\x32\n\x0csparse_table\x18\x02 \x03(\x0b\x32\x1c.paddle.SparseTableParameter\x12\x1d\n\x15push_sparse_per_batch\x18\x03 \x01(\x05\x12\x1c\n\x14push_dense_per_batch\x18\x04 \x01(\x05\x12\x0f\n\x07skip_op\x18\x05 \x03(\t\x12-\n\x0eprogram_config\x18\x06 \x03(\x0b\x32\x15.paddle.ProgramConfig\"\x99\x01\n\rProgramConfig\x12\x12\n\nprogram_id\x18\x01 \x02(\t\x12\x1c\n\x14push_sparse_table_id\x18\x02 \x03(\x05\x12\x1b\n\x13push_dense_table_id\x18\x03 \x03(\x05\x12\x1c\n\x14pull_sparse_table_id\x18\x04 \x03(\x05\x12\x1b\n\x13pull_dense_table_id\x18\x05 \x03(\x05\"{\n\x13\x44\x65nseTableParameter\x12\x10\n\x08table_id\x18\x01 \x01(\x05\x12\x1b\n\x13\x64\x65nse_variable_name\x18\x02 \x03(\t\x12$\n\x1c\x64\x65nse_gradient_variable_name\x18\x03 \x03(\t\x12\x0f\n\x07\x66\x65\x61_dim\x18\x04 \x01(\x05\"z\n\x14SparseTableParameter\x12\x10\n\x08table_id\x18\x01 \x01(\x05\x12\x13\n\x0b\x66\x65\x61ture_dim\x18\x02 \x01(\x05\x12\x10\n\x08slot_key\x18\x03 \x03(\t\x12\x12\n\nslot_value\x18\x04 \x03(\t\x12\x15\n\rslot_gradient\x18\x05 \x03(\t\"\x86\x01\n\x17\x44ownpourServerParameter\x12\x34\n\x14\x64ownpour_table_param\x18\x01 \x03(\x0b\x32\x16.paddle.TableParameter\x12\x35\n\rservice_param\x18\x02 \x01(\x0b\x32\x1e.paddle.ServerServiceParameter\"\xd7\x01\n\x16ServerServiceParameter\x12*\n\x0cserver_class\x18\x01 \x01(\t:\x14\x44ownpourBrpcPsServer\x12*\n\x0c\x63lient_class\x18\x02 \x01(\t:\x14\x44ownpourBrpcPsClient\x12(\n\rservice_class\x18\x03 \x01(\t:\x11\x44ownpourPsService\x12\x1c\n\x11start_server_port\x18\x04 \x01(\r:\x01\x30\x12\x1d\n\x11server_thread_num\x18\x05 \x01(\r:\x02\x31\x32\"\xc4\x01\n\x0eTableParameter\x12\x10\n\x08table_id\x18\x01 \x01(\x04\x12\x13\n\x0btable_class\x18\x02 \x01(\t\x12\x17\n\tshard_num\x18\x03 \x01(\x04:\x04\x31\x30\x30\x30\x12\x30\n\x08\x61\x63\x63\x65ssor\x18\x04 \x01(\x0b\x32\x1e.paddle.TableAccessorParameter\x12\x1f\n\x04type\x18\x05 \x01(\x0e\x32\x11.paddle.TableType\x12\x1f\n\x10\x63ompress_in_save\x18\x06 \x01(\x08:\x05\x66\x61lse\"\xf1\x02\n\x16TableAccessorParameter\x12\x16\n\x0e\x61\x63\x63\x65ssor_class\x18\x01 \x01(\t\x12\x38\n\x10sparse_sgd_param\x18\x02 \x01(\x0b\x32\x1e.paddle.SparseSGDRuleParameter\x12\x36\n\x0f\x64\x65nse_sgd_param\x18\x03 \x01(\x0b\x32\x1d.paddle.DenseSGDRuleParameter\x12\x0f\n\x07\x66\x65\x61_dim\x18\x04 \x01(\r\x12\x12\n\nembedx_dim\x18\x05 \x01(\r\x12\x18\n\x10\x65mbedx_threshold\x18\x06 \x01(\r\x12G\n\x17\x64ownpour_accessor_param\x18\x07 \x01(\x0b\x32&.paddle.DownpourTableAccessorParameter\x12\x45\n\x19table_accessor_save_param\x18\x08 \x03(\x0b\x32\".paddle.TableAccessorSaveParameter\"\xf0\x01\n\x1e\x44ownpourTableAccessorParameter\x12\x14\n\x0cnonclk_coeff\x18\x01 \x01(\x02\x12\x13\n\x0b\x63lick_coeff\x18\x02 \x01(\x02\x12\x16\n\x0e\x62\x61se_threshold\x18\x03 \x01(\x02\x12\x17\n\x0f\x64\x65lta_threshold\x18\x04 \x01(\x02\x12\x17\n\x0f\x64\x65lta_keep_days\x18\x05 \x01(\x02\x12\x1d\n\x15show_click_decay_rate\x18\x06 \x01(\x02\x12\x18\n\x10\x64\x65lete_threshold\x18\x07 \x01(\x02\x12 \n\x18\x64\x65lete_after_unseen_days\x18\x08 \x01(\x02\"S\n\x1aTableAccessorSaveParameter\x12\r\n\x05param\x18\x01 \x01(\r\x12\x11\n\tconverter\x18\x02 \x01(\t\x12\x13\n\x0b\x64\x65\x63onverter\x18\x03 \x01(\t\"e\n\x10PsRequestMessage\x12\x0e\n\x06\x63md_id\x18\x01 \x02(\r\x12\x10\n\x08table_id\x18\x02 \x01(\r\x12\x0e\n\x06params\x18\x03 \x03(\x0c\x12\x11\n\tclient_id\x18\x04 \x01(\x05\x12\x0c\n\x04\x64\x61ta\x18\x05 \x01(\x0c\"w\n\x16SparseSGDRuleParameter\x12\x15\n\rlearning_rate\x18\x01 \x01(\x01\x12\x15\n\rinitial_g2sum\x18\x02 \x01(\x01\x12\x18\n\rinitial_range\x18\x03 \x01(\x01:\x01\x30\x12\x15\n\rweight_bounds\x18\x04 \x03(\x02\"\xe1\x01\n\x15\x44\x65nseSGDRuleParameter\x12\x0c\n\x04name\x18\x01 \x01(\t\x12&\n\x04\x61\x64\x61m\x18\x02 \x01(\x0b\x32\x18.paddle.AdamSGDParameter\x12(\n\x05naive\x18\x03 \x01(\x0b\x32\x19.paddle.NaiveSGDParameter\x12,\n\x07summary\x18\x04 \x01(\x0b\x32\x1b.paddle.SummarySGDParameter\x12:\n\x0emoving_average\x18\x05 \x01(\x0b\x32\".paddle.MovingAverageRuleParameter\"\x86\x01\n\x10\x41\x64\x61mSGDParameter\x12\x15\n\rlearning_rate\x18\x01 \x01(\x01\x12\x16\n\x0e\x61vg_decay_rate\x18\x02 \x01(\x01\x12\x16\n\x0e\x61\x64\x61_decay_rate\x18\x03 \x01(\x01\x12\x13\n\x0b\x61\x64\x61_epsilon\x18\x04 \x01(\x01\x12\x16\n\x0emom_decay_rate\x18\x05 \x01(\x01\"B\n\x11NaiveSGDParameter\x12\x15\n\rlearning_rate\x18\x01 \x01(\x01\x12\x16\n\x0e\x61vg_decay_rate\x18\x02 \x01(\x01\";\n\x13SummarySGDParameter\x12$\n\x12summary_decay_rate\x18\x01 \x01(\x01:\x08\x30.999999\".\n\x1aMovingAverageRuleParameter\x12\x10\n\x08momentum\x18\x01 \x01(\x01\"I\n\x11PsResponseMessage\x12\x13\n\x08\x65rr_code\x18\x01 \x02(\x05:\x01\x30\x12\x11\n\x07\x65rr_msg\x18\x02 \x02(\t:\x00\x12\x0c\n\x04\x64\x61ta\x18\x03 \x01(\x0c\"\xd5\x01\n\x11\x46sClientParameter\x12:\n\x07\x66s_type\x18\x01 \x01(\x0e\x32#.paddle.FsClientParameter.FsApiType:\x04HDFS\x12\x0b\n\x03uri\x18\x02 \x01(\t\x12\x0c\n\x04user\x18\x03 \x01(\t\x12\x0e\n\x06passwd\x18\x04 \x01(\t\x12\x13\n\x0b\x62uffer_size\x18\x05 \x01(\x05\x12\x12\n\nhadoop_bin\x18\x33 \x01(\t\x12\x10\n\x08\x61\x66s_conf\x18\x65 \x01(\t\"\x1e\n\tFsApiType\x12\x08\n\x04HDFS\x10\x00\x12\x07\n\x03\x41\x46S\x10\x01*4\n\tTableType\x12\x13\n\x0fPS_SPARSE_TABLE\x10\x00\x12\x12\n\x0ePS_DENSE_TABLE\x10\x01*\xbd\x02\n\x07PsCmdID\x12\x17\n\x13PS_PULL_DENSE_TABLE\x10\x00\x12\x17\n\x13PS_PUSH_DENSE_TABLE\x10\x01\x12\x18\n\x14PS_PULL_SPARSE_TABLE\x10\x02\x12\x18\n\x14PS_PUSH_SPARSE_TABLE\x10\x03\x12\x13\n\x0fPS_SHRINK_TABLE\x10\x04\x12\x15\n\x11PS_SAVE_ONE_TABLE\x10\x05\x12\x15\n\x11PS_SAVE_ALL_TABLE\x10\x06\x12\x15\n\x11PS_LOAD_ONE_TABLE\x10\x07\x12\x15\n\x11PS_LOAD_ALL_TABLE\x10\x08\x12\x16\n\x12PS_CLEAR_ONE_TABLE\x10\t\x12\x16\n\x12PS_CLEAR_ALL_TABLE\x10\n\x12\x17\n\x13PS_PUSH_DENSE_PARAM\x10\x0b\x12\x12\n\x0ePS_STOP_SERVER\x10\x0c\x32K\n\tPsService\x12>\n\x07service\x12\x18.paddle.PsRequestMessage\x1a\x19.paddle.PsResponseMessageB\x03\x80\x01\x01'
)) ))
_sym_db.RegisterFileDescriptor(DESCRIPTOR) _sym_db.RegisterFileDescriptor(DESCRIPTOR)
...@@ -49,8 +49,8 @@ _TABLETYPE = _descriptor.EnumDescriptor( ...@@ -49,8 +49,8 @@ _TABLETYPE = _descriptor.EnumDescriptor(
], ],
containing_type=None, containing_type=None,
options=None, options=None,
serialized_start=3494, serialized_start=3528,
serialized_end=3546, ) serialized_end=3580, )
_sym_db.RegisterEnumDescriptor(_TABLETYPE) _sym_db.RegisterEnumDescriptor(_TABLETYPE)
TableType = enum_type_wrapper.EnumTypeWrapper(_TABLETYPE) TableType = enum_type_wrapper.EnumTypeWrapper(_TABLETYPE)
...@@ -134,8 +134,8 @@ _PSCMDID = _descriptor.EnumDescriptor( ...@@ -134,8 +134,8 @@ _PSCMDID = _descriptor.EnumDescriptor(
], ],
containing_type=None, containing_type=None,
options=None, options=None,
serialized_start=3549, serialized_start=3583,
serialized_end=3866, ) serialized_end=3900, )
_sym_db.RegisterEnumDescriptor(_PSCMDID) _sym_db.RegisterEnumDescriptor(_PSCMDID)
PsCmdID = enum_type_wrapper.EnumTypeWrapper(_PSCMDID) PsCmdID = enum_type_wrapper.EnumTypeWrapper(_PSCMDID)
...@@ -168,8 +168,8 @@ _FSCLIENTPARAMETER_FSAPITYPE = _descriptor.EnumDescriptor( ...@@ -168,8 +168,8 @@ _FSCLIENTPARAMETER_FSAPITYPE = _descriptor.EnumDescriptor(
], ],
containing_type=None, containing_type=None,
options=None, options=None,
serialized_start=3462, serialized_start=3496,
serialized_end=3492, ) serialized_end=3526, )
_sym_db.RegisterEnumDescriptor(_FSCLIENTPARAMETER_FSAPITYPE) _sym_db.RegisterEnumDescriptor(_FSCLIENTPARAMETER_FSAPITYPE)
_PSPARAMETER = _descriptor.Descriptor( _PSPARAMETER = _descriptor.Descriptor(
...@@ -1335,6 +1335,22 @@ _DOWNPOURTABLEACCESSORPARAMETER = _descriptor.Descriptor( ...@@ -1335,6 +1335,22 @@ _DOWNPOURTABLEACCESSORPARAMETER = _descriptor.Descriptor(
is_extension=False, is_extension=False,
extension_scope=None, extension_scope=None,
options=None), options=None),
_descriptor.FieldDescriptor(
name='delete_after_unseen_days',
full_name='paddle.DownpourTableAccessorParameter.delete_after_unseen_days',
index=7,
number=8,
type=2,
cpp_type=6,
label=1,
has_default_value=False,
default_value=float(0),
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
], ],
extensions=[], extensions=[],
nested_types=[], nested_types=[],
...@@ -1345,7 +1361,7 @@ _DOWNPOURTABLEACCESSORPARAMETER = _descriptor.Descriptor( ...@@ -1345,7 +1361,7 @@ _DOWNPOURTABLEACCESSORPARAMETER = _descriptor.Descriptor(
extension_ranges=[], extension_ranges=[],
oneofs=[], oneofs=[],
serialized_start=2144, serialized_start=2144,
serialized_end=2350, ) serialized_end=2384, )
_TABLEACCESSORSAVEPARAMETER = _descriptor.Descriptor( _TABLEACCESSORSAVEPARAMETER = _descriptor.Descriptor(
name='TableAccessorSaveParameter', name='TableAccessorSaveParameter',
...@@ -1411,8 +1427,8 @@ _TABLEACCESSORSAVEPARAMETER = _descriptor.Descriptor( ...@@ -1411,8 +1427,8 @@ _TABLEACCESSORSAVEPARAMETER = _descriptor.Descriptor(
syntax='proto2', syntax='proto2',
extension_ranges=[], extension_ranges=[],
oneofs=[], oneofs=[],
serialized_start=2352, serialized_start=2386,
serialized_end=2435, ) serialized_end=2469, )
_PSREQUESTMESSAGE = _descriptor.Descriptor( _PSREQUESTMESSAGE = _descriptor.Descriptor(
name='PsRequestMessage', name='PsRequestMessage',
...@@ -1510,8 +1526,8 @@ _PSREQUESTMESSAGE = _descriptor.Descriptor( ...@@ -1510,8 +1526,8 @@ _PSREQUESTMESSAGE = _descriptor.Descriptor(
syntax='proto2', syntax='proto2',
extension_ranges=[], extension_ranges=[],
oneofs=[], oneofs=[],
serialized_start=2437, serialized_start=2471,
serialized_end=2538, ) serialized_end=2572, )
_SPARSESGDRULEPARAMETER = _descriptor.Descriptor( _SPARSESGDRULEPARAMETER = _descriptor.Descriptor(
name='SparseSGDRuleParameter', name='SparseSGDRuleParameter',
...@@ -1593,8 +1609,8 @@ _SPARSESGDRULEPARAMETER = _descriptor.Descriptor( ...@@ -1593,8 +1609,8 @@ _SPARSESGDRULEPARAMETER = _descriptor.Descriptor(
syntax='proto2', syntax='proto2',
extension_ranges=[], extension_ranges=[],
oneofs=[], oneofs=[],
serialized_start=2540, serialized_start=2574,
serialized_end=2659, ) serialized_end=2693, )
_DENSESGDRULEPARAMETER = _descriptor.Descriptor( _DENSESGDRULEPARAMETER = _descriptor.Descriptor(
name='DenseSGDRuleParameter', name='DenseSGDRuleParameter',
...@@ -1692,8 +1708,8 @@ _DENSESGDRULEPARAMETER = _descriptor.Descriptor( ...@@ -1692,8 +1708,8 @@ _DENSESGDRULEPARAMETER = _descriptor.Descriptor(
syntax='proto2', syntax='proto2',
extension_ranges=[], extension_ranges=[],
oneofs=[], oneofs=[],
serialized_start=2662, serialized_start=2696,
serialized_end=2887, ) serialized_end=2921, )
_ADAMSGDPARAMETER = _descriptor.Descriptor( _ADAMSGDPARAMETER = _descriptor.Descriptor(
name='AdamSGDParameter', name='AdamSGDParameter',
...@@ -1791,8 +1807,8 @@ _ADAMSGDPARAMETER = _descriptor.Descriptor( ...@@ -1791,8 +1807,8 @@ _ADAMSGDPARAMETER = _descriptor.Descriptor(
syntax='proto2', syntax='proto2',
extension_ranges=[], extension_ranges=[],
oneofs=[], oneofs=[],
serialized_start=2890, serialized_start=2924,
serialized_end=3024, ) serialized_end=3058, )
_NAIVESGDPARAMETER = _descriptor.Descriptor( _NAIVESGDPARAMETER = _descriptor.Descriptor(
name='NaiveSGDParameter', name='NaiveSGDParameter',
...@@ -1842,8 +1858,8 @@ _NAIVESGDPARAMETER = _descriptor.Descriptor( ...@@ -1842,8 +1858,8 @@ _NAIVESGDPARAMETER = _descriptor.Descriptor(
syntax='proto2', syntax='proto2',
extension_ranges=[], extension_ranges=[],
oneofs=[], oneofs=[],
serialized_start=3026, serialized_start=3060,
serialized_end=3092, ) serialized_end=3126, )
_SUMMARYSGDPARAMETER = _descriptor.Descriptor( _SUMMARYSGDPARAMETER = _descriptor.Descriptor(
name='SummarySGDParameter', name='SummarySGDParameter',
...@@ -1877,8 +1893,8 @@ _SUMMARYSGDPARAMETER = _descriptor.Descriptor( ...@@ -1877,8 +1893,8 @@ _SUMMARYSGDPARAMETER = _descriptor.Descriptor(
syntax='proto2', syntax='proto2',
extension_ranges=[], extension_ranges=[],
oneofs=[], oneofs=[],
serialized_start=3094, serialized_start=3128,
serialized_end=3153, ) serialized_end=3187, )
_MOVINGAVERAGERULEPARAMETER = _descriptor.Descriptor( _MOVINGAVERAGERULEPARAMETER = _descriptor.Descriptor(
name='MovingAverageRuleParameter', name='MovingAverageRuleParameter',
...@@ -1912,8 +1928,8 @@ _MOVINGAVERAGERULEPARAMETER = _descriptor.Descriptor( ...@@ -1912,8 +1928,8 @@ _MOVINGAVERAGERULEPARAMETER = _descriptor.Descriptor(
syntax='proto2', syntax='proto2',
extension_ranges=[], extension_ranges=[],
oneofs=[], oneofs=[],
serialized_start=3155, serialized_start=3189,
serialized_end=3201, ) serialized_end=3235, )
_PSRESPONSEMESSAGE = _descriptor.Descriptor( _PSRESPONSEMESSAGE = _descriptor.Descriptor(
name='PsResponseMessage', name='PsResponseMessage',
...@@ -1979,8 +1995,8 @@ _PSRESPONSEMESSAGE = _descriptor.Descriptor( ...@@ -1979,8 +1995,8 @@ _PSRESPONSEMESSAGE = _descriptor.Descriptor(
syntax='proto2', syntax='proto2',
extension_ranges=[], extension_ranges=[],
oneofs=[], oneofs=[],
serialized_start=3203, serialized_start=3237,
serialized_end=3276, ) serialized_end=3310, )
_FSCLIENTPARAMETER = _descriptor.Descriptor( _FSCLIENTPARAMETER = _descriptor.Descriptor(
name='FsClientParameter', name='FsClientParameter',
...@@ -2110,8 +2126,8 @@ _FSCLIENTPARAMETER = _descriptor.Descriptor( ...@@ -2110,8 +2126,8 @@ _FSCLIENTPARAMETER = _descriptor.Descriptor(
syntax='proto2', syntax='proto2',
extension_ranges=[], extension_ranges=[],
oneofs=[], oneofs=[],
serialized_start=3279, serialized_start=3313,
serialized_end=3492, ) serialized_end=3526, )
_PSPARAMETER.fields_by_name['worker_param'].message_type = _WORKERPARAMETER _PSPARAMETER.fields_by_name['worker_param'].message_type = _WORKERPARAMETER
_PSPARAMETER.fields_by_name['server_param'].message_type = _SERVERPARAMETER _PSPARAMETER.fields_by_name['server_param'].message_type = _SERVERPARAMETER
......
...@@ -17,8 +17,12 @@ from os import path ...@@ -17,8 +17,12 @@ from os import path
__all__ = ['TrainerDesc', 'MultiTrainer', 'DistMultiTrainer', 'PipelineTrainer'] __all__ = ['TrainerDesc', 'MultiTrainer', 'DistMultiTrainer', 'PipelineTrainer']
# can be initialized from train_desc,
class TrainerDesc(object): class TrainerDesc(object):
'''
Set proto from python to c++.
Can be initialized from train_desc.
'''
def __init__(self): def __init__(self):
''' '''
self.proto_desc = data_feed_pb2.DataFeedDesc() self.proto_desc = data_feed_pb2.DataFeedDesc()
...@@ -71,6 +75,9 @@ class TrainerDesc(object): ...@@ -71,6 +75,9 @@ class TrainerDesc(object):
def _set_use_cvm(self, use_cvm=False): def _set_use_cvm(self, use_cvm=False):
self.proto_desc.use_cvm = use_cvm self.proto_desc.use_cvm = use_cvm
def _set_dump_slot(self, dump_slot):
self.proto_desc.dump_slot = dump_slot
def _desc(self): def _desc(self):
from google.protobuf import text_format from google.protobuf import text_format
return self.proto_desc.SerializeToString() return self.proto_desc.SerializeToString()
...@@ -81,6 +88,11 @@ class TrainerDesc(object): ...@@ -81,6 +88,11 @@ class TrainerDesc(object):
class MultiTrainer(TrainerDesc): class MultiTrainer(TrainerDesc):
'''
Implement of MultiTrainer.
Can be init from TrainerDesc.
'''
def __init__(self): def __init__(self):
super(MultiTrainer, self).__init__() super(MultiTrainer, self).__init__()
pass pass
......
...@@ -39,5 +39,6 @@ class TrainerFactory(object): ...@@ -39,5 +39,6 @@ class TrainerFactory(object):
device_worker._set_fleet_desc(opt_info["fleet_desc"]) device_worker._set_fleet_desc(opt_info["fleet_desc"])
trainer._set_fleet_desc(opt_info["fleet_desc"]) trainer._set_fleet_desc(opt_info["fleet_desc"])
trainer._set_use_cvm(opt_info["use_cvm"]) trainer._set_use_cvm(opt_info["use_cvm"])
trainer._set_dump_slot(opt_info["dump_slot"])
trainer._set_device_worker(device_worker) trainer._set_device_worker(device_worker)
return trainer return trainer
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册