From 8de4d31a5b59fea6071eb5c842dd09341fc11234 Mon Sep 17 00:00:00 2001 From: heqiaozhi Date: Thu, 7 Mar 2019 11:16:42 +0800 Subject: [PATCH] refactor async exe --- python/paddle/fluid/async_executor.py | 12 +- python/paddle/fluid/distributed/downpour.py | 86 +++++--- python/paddle/fluid/distributed/node.py | 24 +++ python/paddle/fluid/distributed/ps_pb2.py | 216 ++++++++++++++++---- 4 files changed, 269 insertions(+), 69 deletions(-) diff --git a/python/paddle/fluid/async_executor.py b/python/paddle/fluid/async_executor.py index 61de5ade86..e0e36fa2ee 100644 --- a/python/paddle/fluid/async_executor.py +++ b/python/paddle/fluid/async_executor.py @@ -121,7 +121,9 @@ class AsyncExecutor(object): with open("trainer_desc.proto", "w") as fout: fout.write(trainer._desc()) # define a trainer and a device_worker here - self.executor.run_from_files(program_desc, trainer._desc(), debug) + self.executor.run_from_files(program_desc, + trainer._desc(), debug, + str(id(program_desc))) ''' def run(self, @@ -194,7 +196,7 @@ class AsyncExecutor(object): self.executor.run_from_files(program_desc, data_feed.desc(), filelist, thread_num, - fetch_var_names, mode, debug) + fetch_var_names, mode, debug, str(id(program_desc))) ''' def download_data(self, @@ -313,7 +315,11 @@ class AsyncExecutor(object): self.dist_desc = dist_desc place = core.CPUPlace() executor = Executor(place) - executor.run(startup_program) + if isinstance(startup_program, list): + for sp in startup_program: + executor.run(sp) + else: + executor.run(startup_program) self.instance.barrier_all() #wait all server start ips = self.instance.gather_ips() diff --git a/python/paddle/fluid/distributed/downpour.py b/python/paddle/fluid/distributed/downpour.py index 87dfab92c5..9edb631351 100644 --- a/python/paddle/fluid/distributed/downpour.py +++ b/python/paddle/fluid/distributed/downpour.py @@ -43,9 +43,13 @@ class DownpourSGD(object): self.learning_rate_ = learning_rate self.window_ = window self.type = "downpour" + self.data_norm_name = [ + ".batch_size", ".batch_square_sum", ".batch_sum", + ".batch_size@GRAD", ".batch_square_sum@GRAD", ".batch_sum@GRAD" + ] def minimize(self, - loss, + losses, startup_program=None, parameter_list=None, no_grad_set=None): @@ -65,39 +69,75 @@ class DownpourSGD(object): worker_skipped_ops: operator names that need to be skipped during execution """ - params_grads = sorted( - append_backward(loss, parameter_list, no_grad_set), - key=lambda x: x[0].name) - table_name = find_distributed_lookup_table(loss.block.program) + if not isinstance(losses, list): + raise ValueError('losses is a list, just lick [model.cost]') + table_name = find_distributed_lookup_table(losses[0].block.program) prefetch_slots = find_distributed_lookup_table_inputs( - loss.block.program, table_name) + losses[0].block.program, table_name) prefetch_slots_emb = find_distributed_lookup_table_outputs( - loss.block.program, table_name) + losses[0].block.program, table_name) + + ps_param = pslib.PSParameter() server = DownpourServer() - # window is communication strategy worker = DownpourWorker(self.window_) - # Todo(guru4elephant): support multiple tables definitions - # currently support one big sparse table sparse_table_index = 0 - # currently merge all dense parameters into one dense table - dense_table_index = 1 - params = [] - grads = [] - for i in params_grads: - params.append(i[0]) - for i in params_grads: - grads.append(i[1]) server.add_sparse_table(sparse_table_index, self.learning_rate_, prefetch_slots, prefetch_slots_emb) - server.add_dense_table(dense_table_index, self.learning_rate_, params, - grads) worker.add_sparse_table(sparse_table_index, self.learning_rate_, prefetch_slots, prefetch_slots_emb) - worker.add_dense_table(dense_table_index, self.learning_rate_, params, - grads) - ps_param = pslib.PSParameter() + dense_table_index = 1 + program_configs = [] + for loss_index in range(len(losses)): + program_config = ps_param.trainer_param.program_config.add() + program_config.program_id = str( + id(losses[loss_index].block.program)) + program_config.pull_sparse_table_id.extend([sparse_table_index]) + program_config.push_sparse_table_id.extend([sparse_table_index]) + params_grads = sorted( + append_backward(losses[loss_index], parameter_list, + no_grad_set), + key=lambda x: x[0].name) + params = [] + grads = [] + data_norm_params = [] + data_norm_grads = [] + for i in params_grads: + is_data_norm_data = False + for data_norm_name in self.data_norm_name: + if i[0].name.endswith(data_norm_name): + is_data_norm_data = True + data_norm_params.append(i[0]) + if not is_data_norm_data: + params.append(i[0]) + for i in params_grads: + is_data_norm_data = False + for data_norm_grad in self.data_norm_name: + if i[0].name.endswith(data_norm_grad): + is_data_norm_data = True + data_norm_grads.append(i[1]) + if not is_data_norm_data: + grads.append(i[1]) + server.add_dense_table(dense_table_index, self.learning_rate_, + params, grads) + worker.add_dense_table(dense_table_index, self.learning_rate_, + params, grads) + program_config.pull_dense_table_id.extend([dense_table_index]) + program_config.push_dense_table_id.extend([dense_table_index]) + if len(data_norm_params) != 0 and len(data_norm_grads) != 0: + dense_table_index += 1 + server.add_data_norm_table(dense_table_index, + self.learning_rate_, + data_norm_params, data_norm_grads) + worker.add_dense_table(dense_table_index, self.learning_rate_, + data_norm_params, data_norm_grads) + program_config.pull_dense_table_id.extend([dense_table_index]) + program_config.push_dense_table_id.extend([dense_table_index]) + dense_table_index += 1 + program_configs.append(program_config) ps_param.server_param.CopyFrom(server.get_desc()) ps_param.trainer_param.CopyFrom(worker.get_desc()) + for program_config in program_configs: + ps_param.trainer_param.program_config.extend([program_config]) # Todo(guru4elephant): figure out how to support more sparse parameters # currently only support lookup_table worker_skipped_ops = ["lookup_table", "lookup_table_grad"] diff --git a/python/paddle/fluid/distributed/node.py b/python/paddle/fluid/distributed/node.py index 41e0d64e0b..60035b6e8d 100644 --- a/python/paddle/fluid/distributed/node.py +++ b/python/paddle/fluid/distributed/node.py @@ -112,6 +112,30 @@ class DownpourServer(Server): fea_dim += reduce(lambda x, y: x * y, param.shape, 1) table.accessor.fea_dim = fea_dim + def add_data_norm_table(self, table_id, learning_rate, param_var, grad_var): + """ + Args: + table_id(int): id of sparse params table + learning_rate(float): the learning rate used to update parameters. \ + Can be a float value + param_var(list): all dense param. it is a list. + grad_var(list): all dense grad parm it is a list. + Returns: + return None + """ + table = self.server_.downpour_server_param.downpour_table_param.add() + table.table_id = table_id + table.table_class = "DownpourDenseTable" + table.type = pslib.PS_DENSE_TABLE + table.accessor.accessor_class = "DownpourDenseValueAccessor" + table.accessor.dense_sgd_param.name = "summary" + table.accessor.dense_sgd_param.summary.summary_decay_rate = 0.999999 + fea_dim = 0 + for param in filter(lambda x: x.name.find("embedding") == -1, + param_var): + fea_dim += reduce(lambda x, y: x * y, param.shape, 1) + table.accessor.fea_dim = fea_dim + def get_desc(self): """ Return downpour server program_desc diff --git a/python/paddle/fluid/distributed/ps_pb2.py b/python/paddle/fluid/distributed/ps_pb2.py index 0d226c4d59..5c9b2def07 100644 --- a/python/paddle/fluid/distributed/ps_pb2.py +++ b/python/paddle/fluid/distributed/ps_pb2.py @@ -1,4 +1,4 @@ -# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -10,6 +10,8 @@ # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and +# limitations under the License. + # Generated by the protocol buffer compiler. DO NOT EDIT! # source: ps.proto @@ -30,7 +32,7 @@ DESCRIPTOR = _descriptor.FileDescriptor( package='paddle', syntax='proto2', serialized_pb=_b( - '\n\x08ps.proto\x12\x06paddle\"\x9e\x02\n\x0bPSParameter\x12\x14\n\x0cworker_class\x18\x01 \x01(\t\x12\x14\n\x0cserver_class\x18\x02 \x01(\t\x12\x16\n\x0einstance_class\x18\x03 \x01(\t\x12-\n\x0cworker_param\x18\x65 \x01(\x0b\x32\x17.paddle.WorkerParameter\x12-\n\x0cserver_param\x18\x66 \x01(\x0b\x32\x17.paddle.ServerParameter\x12\x38\n\rtrainer_param\x18\xad\x02 \x01(\x0b\x32 .paddle.DownpourTrainerParameter\x12\x33\n\x0f\x66s_client_param\x18\xf5\x03 \x01(\x0b\x32\x19.paddle.FsClientParameter\"Q\n\x0fWorkerParameter\x12>\n\x15\x64ownpour_worker_param\x18\x01 \x01(\x0b\x32\x1f.paddle.DownpourWorkerParameter\"Q\n\x0fServerParameter\x12>\n\x15\x64ownpour_server_param\x18\x01 \x01(\x0b\x32\x1f.paddle.DownpourServerParameter\"O\n\x17\x44ownpourWorkerParameter\x12\x34\n\x14\x64ownpour_table_param\x18\x01 \x03(\x0b\x32\x16.paddle.TableParameter\"\xce\x01\n\x18\x44ownpourTrainerParameter\x12\x30\n\x0b\x64\x65nse_table\x18\x01 \x03(\x0b\x32\x1b.paddle.DenseTableParameter\x12\x32\n\x0csparse_table\x18\x02 \x03(\x0b\x32\x1c.paddle.SparseTableParameter\x12\x1d\n\x15push_sparse_per_batch\x18\x03 \x01(\x05\x12\x1c\n\x14push_dense_per_batch\x18\x04 \x01(\x05\x12\x0f\n\x07skip_op\x18\x05 \x03(\t\"{\n\x13\x44\x65nseTableParameter\x12\x10\n\x08table_id\x18\x01 \x01(\x05\x12\x1b\n\x13\x64\x65nse_variable_name\x18\x02 \x03(\t\x12$\n\x1c\x64\x65nse_gradient_variable_name\x18\x03 \x03(\t\x12\x0f\n\x07\x66\x65\x61_dim\x18\x04 \x01(\x05\"z\n\x14SparseTableParameter\x12\x10\n\x08table_id\x18\x01 \x01(\x05\x12\x13\n\x0b\x66\x65\x61ture_dim\x18\x02 \x01(\x05\x12\x10\n\x08slot_key\x18\x03 \x03(\t\x12\x12\n\nslot_value\x18\x04 \x03(\t\x12\x15\n\rslot_gradient\x18\x05 \x03(\t\"\x86\x01\n\x17\x44ownpourServerParameter\x12\x34\n\x14\x64ownpour_table_param\x18\x01 \x03(\x0b\x32\x16.paddle.TableParameter\x12\x35\n\rservice_param\x18\x02 \x01(\x0b\x32\x1e.paddle.ServerServiceParameter\"\xd7\x01\n\x16ServerServiceParameter\x12*\n\x0cserver_class\x18\x01 \x01(\t:\x14\x44ownpourBrpcPsServer\x12*\n\x0c\x63lient_class\x18\x02 \x01(\t:\x14\x44ownpourBrpcPsClient\x12(\n\rservice_class\x18\x03 \x01(\t:\x11\x44ownpourPsService\x12\x1c\n\x11start_server_port\x18\x04 \x01(\r:\x01\x30\x12\x1d\n\x11server_thread_num\x18\x05 \x01(\r:\x02\x31\x32\"\xbf\x01\n\x0eTableParameter\x12\x10\n\x08table_id\x18\x01 \x01(\x04\x12\x13\n\x0btable_class\x18\x02 \x01(\t\x12\x12\n\nshared_num\x18\x03 \x01(\x04\x12\x30\n\x08\x61\x63\x63\x65ssor\x18\x04 \x01(\x0b\x32\x1e.paddle.TableAccessorParameter\x12\x1f\n\x04type\x18\x05 \x01(\x0e\x32\x11.paddle.TableType\x12\x1f\n\x10\x63ompress_in_save\x18\x06 \x01(\x08:\x05\x66\x61lse\"\xf1\x02\n\x16TableAccessorParameter\x12\x16\n\x0e\x61\x63\x63\x65ssor_class\x18\x01 \x01(\t\x12\x38\n\x10sparse_sgd_param\x18\x02 \x01(\x0b\x32\x1e.paddle.SparseSGDRuleParameter\x12\x36\n\x0f\x64\x65nse_sgd_param\x18\x03 \x01(\x0b\x32\x1d.paddle.DenseSGDRuleParameter\x12\x0f\n\x07\x66\x65\x61_dim\x18\x04 \x01(\r\x12\x12\n\nembedx_dim\x18\x05 \x01(\r\x12\x18\n\x10\x65mbedx_threshold\x18\x06 \x01(\r\x12G\n\x17\x64ownpour_accessor_param\x18\x07 \x01(\x0b\x32&.paddle.DownpourTableAccessorParameter\x12\x45\n\x19table_accessor_save_param\x18\x08 \x03(\x0b\x32\".paddle.TableAccessorSaveParameter\"\xce\x01\n\x1e\x44ownpourTableAccessorParameter\x12\x14\n\x0cnonclk_coeff\x18\x01 \x01(\x02\x12\x13\n\x0b\x63lick_coeff\x18\x02 \x01(\x02\x12\x16\n\x0e\x62\x61se_threshold\x18\x03 \x01(\x02\x12\x17\n\x0f\x64\x65lta_threshold\x18\x04 \x01(\x02\x12\x17\n\x0f\x64\x65lta_keep_days\x18\x05 \x01(\x02\x12\x1d\n\x15show_click_decay_rate\x18\x06 \x01(\x02\x12\x18\n\x10\x64\x65lete_threshold\x18\x07 \x01(\x02\"S\n\x1aTableAccessorSaveParameter\x12\r\n\x05param\x18\x01 \x01(\r\x12\x11\n\tconverter\x18\x02 \x01(\t\x12\x13\n\x0b\x64\x65\x63onverter\x18\x03 \x01(\t\"e\n\x10PsRequestMessage\x12\x0e\n\x06\x63md_id\x18\x01 \x02(\r\x12\x10\n\x08table_id\x18\x02 \x01(\r\x12\x0e\n\x06params\x18\x03 \x03(\x0c\x12\x11\n\tclient_id\x18\x04 \x01(\x05\x12\x0c\n\x04\x64\x61ta\x18\x05 \x01(\x0c\"w\n\x16SparseSGDRuleParameter\x12\x15\n\rlearning_rate\x18\x01 \x01(\x01\x12\x15\n\rinitial_g2sum\x18\x02 \x01(\x01\x12\x18\n\rinitial_range\x18\x03 \x01(\x01:\x01\x30\x12\x15\n\rweight_bounds\x18\x04 \x03(\x02\"\xe1\x01\n\x15\x44\x65nseSGDRuleParameter\x12\x0c\n\x04name\x18\x01 \x01(\t\x12&\n\x04\x61\x64\x61m\x18\x02 \x01(\x0b\x32\x18.paddle.AdamSGDParameter\x12(\n\x05naive\x18\x03 \x01(\x0b\x32\x19.paddle.NaiveSGDParameter\x12,\n\x07summary\x18\x04 \x01(\x0b\x32\x1b.paddle.SummarySGDParameter\x12:\n\x0emoving_average\x18\x05 \x01(\x0b\x32\".paddle.MovingAverageRuleParameter\"\x86\x01\n\x10\x41\x64\x61mSGDParameter\x12\x15\n\rlearning_rate\x18\x01 \x01(\x01\x12\x16\n\x0e\x61vg_decay_rate\x18\x02 \x01(\x01\x12\x16\n\x0e\x61\x64\x61_decay_rate\x18\x03 \x01(\x01\x12\x13\n\x0b\x61\x64\x61_epsilon\x18\x04 \x01(\x01\x12\x16\n\x0emom_decay_rate\x18\x05 \x01(\x01\"B\n\x11NaiveSGDParameter\x12\x15\n\rlearning_rate\x18\x01 \x01(\x01\x12\x16\n\x0e\x61vg_decay_rate\x18\x02 \x01(\x01\";\n\x13SummarySGDParameter\x12$\n\x12summary_decay_rate\x18\x01 \x01(\x01:\x08\x30.999999\".\n\x1aMovingAverageRuleParameter\x12\x10\n\x08momentum\x18\x01 \x01(\x01\"I\n\x11PsResponseMessage\x12\x13\n\x08\x65rr_code\x18\x01 \x02(\x05:\x01\x30\x12\x11\n\x07\x65rr_msg\x18\x02 \x02(\t:\x00\x12\x0c\n\x04\x64\x61ta\x18\x03 \x01(\x0c\"\xd5\x01\n\x11\x46sClientParameter\x12:\n\x07\x66s_type\x18\x01 \x01(\x0e\x32#.paddle.FsClientParameter.FsApiType:\x04HDFS\x12\x0b\n\x03uri\x18\x02 \x01(\t\x12\x0c\n\x04user\x18\x03 \x01(\t\x12\x0e\n\x06passwd\x18\x04 \x01(\t\x12\x13\n\x0b\x62uffer_size\x18\x05 \x01(\x05\x12\x12\n\nhadoop_bin\x18\x33 \x01(\t\x12\x10\n\x08\x61\x66s_conf\x18\x65 \x01(\t\"\x1e\n\tFsApiType\x12\x08\n\x04HDFS\x10\x00\x12\x07\n\x03\x41\x46S\x10\x01*4\n\tTableType\x12\x13\n\x0fPS_SPARSE_TABLE\x10\x00\x12\x12\n\x0ePS_DENSE_TABLE\x10\x01*\xbd\x02\n\x07PsCmdID\x12\x17\n\x13PS_PULL_DENSE_TABLE\x10\x00\x12\x17\n\x13PS_PUSH_DENSE_TABLE\x10\x01\x12\x18\n\x14PS_PULL_SPARSE_TABLE\x10\x02\x12\x18\n\x14PS_PUSH_SPARSE_TABLE\x10\x03\x12\x13\n\x0fPS_SHRINK_TABLE\x10\x04\x12\x15\n\x11PS_SAVE_ONE_TABLE\x10\x05\x12\x15\n\x11PS_SAVE_ALL_TABLE\x10\x06\x12\x15\n\x11PS_LOAD_ONE_TABLE\x10\x07\x12\x15\n\x11PS_LOAD_ALL_TABLE\x10\x08\x12\x16\n\x12PS_CLEAR_ONE_TABLE\x10\t\x12\x16\n\x12PS_CLEAR_ALL_TABLE\x10\n\x12\x17\n\x13PS_PUSH_DENSE_PARAM\x10\x0b\x12\x12\n\x0ePS_STOP_SERVER\x10\x0c\x32K\n\tPsService\x12>\n\x07service\x12\x18.paddle.PsRequestMessage\x1a\x19.paddle.PsResponseMessageB\x03\x80\x01\x01' + '\n\x08ps.proto\x12\x06paddle\"\x9e\x02\n\x0bPSParameter\x12\x14\n\x0cworker_class\x18\x01 \x01(\t\x12\x14\n\x0cserver_class\x18\x02 \x01(\t\x12\x16\n\x0einstance_class\x18\x03 \x01(\t\x12-\n\x0cworker_param\x18\x65 \x01(\x0b\x32\x17.paddle.WorkerParameter\x12-\n\x0cserver_param\x18\x66 \x01(\x0b\x32\x17.paddle.ServerParameter\x12\x38\n\rtrainer_param\x18\xad\x02 \x01(\x0b\x32 .paddle.DownpourTrainerParameter\x12\x33\n\x0f\x66s_client_param\x18\xf5\x03 \x01(\x0b\x32\x19.paddle.FsClientParameter\"Q\n\x0fWorkerParameter\x12>\n\x15\x64ownpour_worker_param\x18\x01 \x01(\x0b\x32\x1f.paddle.DownpourWorkerParameter\"Q\n\x0fServerParameter\x12>\n\x15\x64ownpour_server_param\x18\x01 \x01(\x0b\x32\x1f.paddle.DownpourServerParameter\"O\n\x17\x44ownpourWorkerParameter\x12\x34\n\x14\x64ownpour_table_param\x18\x01 \x03(\x0b\x32\x16.paddle.TableParameter\"\xfd\x01\n\x18\x44ownpourTrainerParameter\x12\x30\n\x0b\x64\x65nse_table\x18\x01 \x03(\x0b\x32\x1b.paddle.DenseTableParameter\x12\x32\n\x0csparse_table\x18\x02 \x03(\x0b\x32\x1c.paddle.SparseTableParameter\x12\x1d\n\x15push_sparse_per_batch\x18\x03 \x01(\x05\x12\x1c\n\x14push_dense_per_batch\x18\x04 \x01(\x05\x12\x0f\n\x07skip_op\x18\x05 \x03(\t\x12-\n\x0eprogram_config\x18\x06 \x03(\x0b\x32\x15.paddle.ProgramConfig\"\x99\x01\n\rProgramConfig\x12\x12\n\nprogram_id\x18\x01 \x02(\t\x12\x1c\n\x14push_sparse_table_id\x18\x02 \x03(\x05\x12\x1b\n\x13push_dense_table_id\x18\x03 \x03(\x05\x12\x1c\n\x14pull_sparse_table_id\x18\x04 \x03(\x05\x12\x1b\n\x13pull_dense_table_id\x18\x05 \x03(\x05\"{\n\x13\x44\x65nseTableParameter\x12\x10\n\x08table_id\x18\x01 \x01(\x05\x12\x1b\n\x13\x64\x65nse_variable_name\x18\x02 \x03(\t\x12$\n\x1c\x64\x65nse_gradient_variable_name\x18\x03 \x03(\t\x12\x0f\n\x07\x66\x65\x61_dim\x18\x04 \x01(\x05\"z\n\x14SparseTableParameter\x12\x10\n\x08table_id\x18\x01 \x01(\x05\x12\x13\n\x0b\x66\x65\x61ture_dim\x18\x02 \x01(\x05\x12\x10\n\x08slot_key\x18\x03 \x03(\t\x12\x12\n\nslot_value\x18\x04 \x03(\t\x12\x15\n\rslot_gradient\x18\x05 \x03(\t\"\x86\x01\n\x17\x44ownpourServerParameter\x12\x34\n\x14\x64ownpour_table_param\x18\x01 \x03(\x0b\x32\x16.paddle.TableParameter\x12\x35\n\rservice_param\x18\x02 \x01(\x0b\x32\x1e.paddle.ServerServiceParameter\"\xd7\x01\n\x16ServerServiceParameter\x12*\n\x0cserver_class\x18\x01 \x01(\t:\x14\x44ownpourBrpcPsServer\x12*\n\x0c\x63lient_class\x18\x02 \x01(\t:\x14\x44ownpourBrpcPsClient\x12(\n\rservice_class\x18\x03 \x01(\t:\x11\x44ownpourPsService\x12\x1c\n\x11start_server_port\x18\x04 \x01(\r:\x01\x30\x12\x1d\n\x11server_thread_num\x18\x05 \x01(\r:\x02\x31\x32\"\xbf\x01\n\x0eTableParameter\x12\x10\n\x08table_id\x18\x01 \x01(\x04\x12\x13\n\x0btable_class\x18\x02 \x01(\t\x12\x12\n\nshared_num\x18\x03 \x01(\x04\x12\x30\n\x08\x61\x63\x63\x65ssor\x18\x04 \x01(\x0b\x32\x1e.paddle.TableAccessorParameter\x12\x1f\n\x04type\x18\x05 \x01(\x0e\x32\x11.paddle.TableType\x12\x1f\n\x10\x63ompress_in_save\x18\x06 \x01(\x08:\x05\x66\x61lse\"\xf1\x02\n\x16TableAccessorParameter\x12\x16\n\x0e\x61\x63\x63\x65ssor_class\x18\x01 \x01(\t\x12\x38\n\x10sparse_sgd_param\x18\x02 \x01(\x0b\x32\x1e.paddle.SparseSGDRuleParameter\x12\x36\n\x0f\x64\x65nse_sgd_param\x18\x03 \x01(\x0b\x32\x1d.paddle.DenseSGDRuleParameter\x12\x0f\n\x07\x66\x65\x61_dim\x18\x04 \x01(\r\x12\x12\n\nembedx_dim\x18\x05 \x01(\r\x12\x18\n\x10\x65mbedx_threshold\x18\x06 \x01(\r\x12G\n\x17\x64ownpour_accessor_param\x18\x07 \x01(\x0b\x32&.paddle.DownpourTableAccessorParameter\x12\x45\n\x19table_accessor_save_param\x18\x08 \x03(\x0b\x32\".paddle.TableAccessorSaveParameter\"\xce\x01\n\x1e\x44ownpourTableAccessorParameter\x12\x14\n\x0cnonclk_coeff\x18\x01 \x01(\x02\x12\x13\n\x0b\x63lick_coeff\x18\x02 \x01(\x02\x12\x16\n\x0e\x62\x61se_threshold\x18\x03 \x01(\x02\x12\x17\n\x0f\x64\x65lta_threshold\x18\x04 \x01(\x02\x12\x17\n\x0f\x64\x65lta_keep_days\x18\x05 \x01(\x02\x12\x1d\n\x15show_click_decay_rate\x18\x06 \x01(\x02\x12\x18\n\x10\x64\x65lete_threshold\x18\x07 \x01(\x02\"S\n\x1aTableAccessorSaveParameter\x12\r\n\x05param\x18\x01 \x01(\r\x12\x11\n\tconverter\x18\x02 \x01(\t\x12\x13\n\x0b\x64\x65\x63onverter\x18\x03 \x01(\t\"e\n\x10PsRequestMessage\x12\x0e\n\x06\x63md_id\x18\x01 \x02(\r\x12\x10\n\x08table_id\x18\x02 \x01(\r\x12\x0e\n\x06params\x18\x03 \x03(\x0c\x12\x11\n\tclient_id\x18\x04 \x01(\x05\x12\x0c\n\x04\x64\x61ta\x18\x05 \x01(\x0c\"w\n\x16SparseSGDRuleParameter\x12\x15\n\rlearning_rate\x18\x01 \x01(\x01\x12\x15\n\rinitial_g2sum\x18\x02 \x01(\x01\x12\x18\n\rinitial_range\x18\x03 \x01(\x01:\x01\x30\x12\x15\n\rweight_bounds\x18\x04 \x03(\x02\"\xe1\x01\n\x15\x44\x65nseSGDRuleParameter\x12\x0c\n\x04name\x18\x01 \x01(\t\x12&\n\x04\x61\x64\x61m\x18\x02 \x01(\x0b\x32\x18.paddle.AdamSGDParameter\x12(\n\x05naive\x18\x03 \x01(\x0b\x32\x19.paddle.NaiveSGDParameter\x12,\n\x07summary\x18\x04 \x01(\x0b\x32\x1b.paddle.SummarySGDParameter\x12:\n\x0emoving_average\x18\x05 \x01(\x0b\x32\".paddle.MovingAverageRuleParameter\"\x86\x01\n\x10\x41\x64\x61mSGDParameter\x12\x15\n\rlearning_rate\x18\x01 \x01(\x01\x12\x16\n\x0e\x61vg_decay_rate\x18\x02 \x01(\x01\x12\x16\n\x0e\x61\x64\x61_decay_rate\x18\x03 \x01(\x01\x12\x13\n\x0b\x61\x64\x61_epsilon\x18\x04 \x01(\x01\x12\x16\n\x0emom_decay_rate\x18\x05 \x01(\x01\"B\n\x11NaiveSGDParameter\x12\x15\n\rlearning_rate\x18\x01 \x01(\x01\x12\x16\n\x0e\x61vg_decay_rate\x18\x02 \x01(\x01\";\n\x13SummarySGDParameter\x12$\n\x12summary_decay_rate\x18\x01 \x01(\x01:\x08\x30.999999\".\n\x1aMovingAverageRuleParameter\x12\x10\n\x08momentum\x18\x01 \x01(\x01\"I\n\x11PsResponseMessage\x12\x13\n\x08\x65rr_code\x18\x01 \x02(\x05:\x01\x30\x12\x11\n\x07\x65rr_msg\x18\x02 \x02(\t:\x00\x12\x0c\n\x04\x64\x61ta\x18\x03 \x01(\x0c\"\xd5\x01\n\x11\x46sClientParameter\x12:\n\x07\x66s_type\x18\x01 \x01(\x0e\x32#.paddle.FsClientParameter.FsApiType:\x04HDFS\x12\x0b\n\x03uri\x18\x02 \x01(\t\x12\x0c\n\x04user\x18\x03 \x01(\t\x12\x0e\n\x06passwd\x18\x04 \x01(\t\x12\x13\n\x0b\x62uffer_size\x18\x05 \x01(\x05\x12\x12\n\nhadoop_bin\x18\x33 \x01(\t\x12\x10\n\x08\x61\x66s_conf\x18\x65 \x01(\t\"\x1e\n\tFsApiType\x12\x08\n\x04HDFS\x10\x00\x12\x07\n\x03\x41\x46S\x10\x01*4\n\tTableType\x12\x13\n\x0fPS_SPARSE_TABLE\x10\x00\x12\x12\n\x0ePS_DENSE_TABLE\x10\x01*\xbd\x02\n\x07PsCmdID\x12\x17\n\x13PS_PULL_DENSE_TABLE\x10\x00\x12\x17\n\x13PS_PUSH_DENSE_TABLE\x10\x01\x12\x18\n\x14PS_PULL_SPARSE_TABLE\x10\x02\x12\x18\n\x14PS_PUSH_SPARSE_TABLE\x10\x03\x12\x13\n\x0fPS_SHRINK_TABLE\x10\x04\x12\x15\n\x11PS_SAVE_ONE_TABLE\x10\x05\x12\x15\n\x11PS_SAVE_ALL_TABLE\x10\x06\x12\x15\n\x11PS_LOAD_ONE_TABLE\x10\x07\x12\x15\n\x11PS_LOAD_ALL_TABLE\x10\x08\x12\x16\n\x12PS_CLEAR_ONE_TABLE\x10\t\x12\x16\n\x12PS_CLEAR_ALL_TABLE\x10\n\x12\x17\n\x13PS_PUSH_DENSE_PARAM\x10\x0b\x12\x12\n\x0ePS_STOP_SERVER\x10\x0c\x32K\n\tPsService\x12>\n\x07service\x12\x18.paddle.PsRequestMessage\x1a\x19.paddle.PsResponseMessageB\x03\x80\x01\x01' )) _sym_db.RegisterFileDescriptor(DESCRIPTOR) @@ -47,8 +49,8 @@ _TABLETYPE = _descriptor.EnumDescriptor( ], containing_type=None, options=None, - serialized_start=3286, - serialized_end=3338, ) + serialized_start=3489, + serialized_end=3541, ) _sym_db.RegisterEnumDescriptor(_TABLETYPE) TableType = enum_type_wrapper.EnumTypeWrapper(_TABLETYPE) @@ -132,8 +134,8 @@ _PSCMDID = _descriptor.EnumDescriptor( ], containing_type=None, options=None, - serialized_start=3341, - serialized_end=3658, ) + serialized_start=3544, + serialized_end=3861, ) _sym_db.RegisterEnumDescriptor(_PSCMDID) PsCmdID = enum_type_wrapper.EnumTypeWrapper(_PSCMDID) @@ -166,8 +168,8 @@ _FSCLIENTPARAMETER_FSAPITYPE = _descriptor.EnumDescriptor( ], containing_type=None, options=None, - serialized_start=3254, - serialized_end=3284, ) + serialized_start=3457, + serialized_end=3487, ) _sym_db.RegisterEnumDescriptor(_FSCLIENTPARAMETER_FSAPITYPE) _PSPARAMETER = _descriptor.Descriptor( @@ -493,6 +495,22 @@ _DOWNPOURTRAINERPARAMETER = _descriptor.Descriptor( is_extension=False, extension_scope=None, options=None), + _descriptor.FieldDescriptor( + name='program_config', + full_name='paddle.DownpourTrainerParameter.program_config', + index=5, + number=6, + type=11, + cpp_type=10, + label=3, + has_default_value=False, + default_value=[], + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None), ], extensions=[], nested_types=[], @@ -503,7 +521,106 @@ _DOWNPOURTRAINERPARAMETER = _descriptor.Descriptor( extension_ranges=[], oneofs=[], serialized_start=557, - serialized_end=763, ) + serialized_end=810, ) + +_PROGRAMCONFIG = _descriptor.Descriptor( + name='ProgramConfig', + full_name='paddle.ProgramConfig', + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name='program_id', + full_name='paddle.ProgramConfig.program_id', + index=0, + number=1, + type=9, + cpp_type=9, + label=2, + has_default_value=False, + default_value=_b("").decode('utf-8'), + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None), + _descriptor.FieldDescriptor( + name='push_sparse_table_id', + full_name='paddle.ProgramConfig.push_sparse_table_id', + index=1, + number=2, + type=5, + cpp_type=1, + label=3, + has_default_value=False, + default_value=[], + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None), + _descriptor.FieldDescriptor( + name='push_dense_table_id', + full_name='paddle.ProgramConfig.push_dense_table_id', + index=2, + number=3, + type=5, + cpp_type=1, + label=3, + has_default_value=False, + default_value=[], + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None), + _descriptor.FieldDescriptor( + name='pull_sparse_table_id', + full_name='paddle.ProgramConfig.pull_sparse_table_id', + index=3, + number=4, + type=5, + cpp_type=1, + label=3, + has_default_value=False, + default_value=[], + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None), + _descriptor.FieldDescriptor( + name='pull_dense_table_id', + full_name='paddle.ProgramConfig.pull_dense_table_id', + index=4, + number=5, + type=5, + cpp_type=1, + label=3, + has_default_value=False, + default_value=[], + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None), + ], + extensions=[], + nested_types=[], + enum_types=[], + options=None, + is_extendable=False, + syntax='proto2', + extension_ranges=[], + oneofs=[], + serialized_start=813, + serialized_end=966, ) _DENSETABLEPARAMETER = _descriptor.Descriptor( name='DenseTableParameter', @@ -585,8 +702,8 @@ _DENSETABLEPARAMETER = _descriptor.Descriptor( syntax='proto2', extension_ranges=[], oneofs=[], - serialized_start=765, - serialized_end=888, ) + serialized_start=968, + serialized_end=1091, ) _SPARSETABLEPARAMETER = _descriptor.Descriptor( name='SparseTableParameter', @@ -684,8 +801,8 @@ _SPARSETABLEPARAMETER = _descriptor.Descriptor( syntax='proto2', extension_ranges=[], oneofs=[], - serialized_start=890, - serialized_end=1012, ) + serialized_start=1093, + serialized_end=1215, ) _DOWNPOURSERVERPARAMETER = _descriptor.Descriptor( name='DownpourServerParameter', @@ -735,8 +852,8 @@ _DOWNPOURSERVERPARAMETER = _descriptor.Descriptor( syntax='proto2', extension_ranges=[], oneofs=[], - serialized_start=1015, - serialized_end=1149, ) + serialized_start=1218, + serialized_end=1352, ) _SERVERSERVICEPARAMETER = _descriptor.Descriptor( name='ServerServiceParameter', @@ -834,8 +951,8 @@ _SERVERSERVICEPARAMETER = _descriptor.Descriptor( syntax='proto2', extension_ranges=[], oneofs=[], - serialized_start=1152, - serialized_end=1367, ) + serialized_start=1355, + serialized_end=1570, ) _TABLEPARAMETER = _descriptor.Descriptor( name='TableParameter', @@ -949,8 +1066,8 @@ _TABLEPARAMETER = _descriptor.Descriptor( syntax='proto2', extension_ranges=[], oneofs=[], - serialized_start=1370, - serialized_end=1561, ) + serialized_start=1573, + serialized_end=1764, ) _TABLEACCESSORPARAMETER = _descriptor.Descriptor( name='TableAccessorParameter', @@ -1096,8 +1213,8 @@ _TABLEACCESSORPARAMETER = _descriptor.Descriptor( syntax='proto2', extension_ranges=[], oneofs=[], - serialized_start=1564, - serialized_end=1933, ) + serialized_start=1767, + serialized_end=2136, ) _DOWNPOURTABLEACCESSORPARAMETER = _descriptor.Descriptor( name='DownpourTableAccessorParameter', @@ -1227,8 +1344,8 @@ _DOWNPOURTABLEACCESSORPARAMETER = _descriptor.Descriptor( syntax='proto2', extension_ranges=[], oneofs=[], - serialized_start=1936, - serialized_end=2142, ) + serialized_start=2139, + serialized_end=2345, ) _TABLEACCESSORSAVEPARAMETER = _descriptor.Descriptor( name='TableAccessorSaveParameter', @@ -1294,8 +1411,8 @@ _TABLEACCESSORSAVEPARAMETER = _descriptor.Descriptor( syntax='proto2', extension_ranges=[], oneofs=[], - serialized_start=2144, - serialized_end=2227, ) + serialized_start=2347, + serialized_end=2430, ) _PSREQUESTMESSAGE = _descriptor.Descriptor( name='PsRequestMessage', @@ -1393,8 +1510,8 @@ _PSREQUESTMESSAGE = _descriptor.Descriptor( syntax='proto2', extension_ranges=[], oneofs=[], - serialized_start=2229, - serialized_end=2330, ) + serialized_start=2432, + serialized_end=2533, ) _SPARSESGDRULEPARAMETER = _descriptor.Descriptor( name='SparseSGDRuleParameter', @@ -1476,8 +1593,8 @@ _SPARSESGDRULEPARAMETER = _descriptor.Descriptor( syntax='proto2', extension_ranges=[], oneofs=[], - serialized_start=2332, - serialized_end=2451, ) + serialized_start=2535, + serialized_end=2654, ) _DENSESGDRULEPARAMETER = _descriptor.Descriptor( name='DenseSGDRuleParameter', @@ -1575,8 +1692,8 @@ _DENSESGDRULEPARAMETER = _descriptor.Descriptor( syntax='proto2', extension_ranges=[], oneofs=[], - serialized_start=2454, - serialized_end=2679, ) + serialized_start=2657, + serialized_end=2882, ) _ADAMSGDPARAMETER = _descriptor.Descriptor( name='AdamSGDParameter', @@ -1674,8 +1791,8 @@ _ADAMSGDPARAMETER = _descriptor.Descriptor( syntax='proto2', extension_ranges=[], oneofs=[], - serialized_start=2682, - serialized_end=2816, ) + serialized_start=2885, + serialized_end=3019, ) _NAIVESGDPARAMETER = _descriptor.Descriptor( name='NaiveSGDParameter', @@ -1725,8 +1842,8 @@ _NAIVESGDPARAMETER = _descriptor.Descriptor( syntax='proto2', extension_ranges=[], oneofs=[], - serialized_start=2818, - serialized_end=2884, ) + serialized_start=3021, + serialized_end=3087, ) _SUMMARYSGDPARAMETER = _descriptor.Descriptor( name='SummarySGDParameter', @@ -1760,8 +1877,8 @@ _SUMMARYSGDPARAMETER = _descriptor.Descriptor( syntax='proto2', extension_ranges=[], oneofs=[], - serialized_start=2886, - serialized_end=2945, ) + serialized_start=3089, + serialized_end=3148, ) _MOVINGAVERAGERULEPARAMETER = _descriptor.Descriptor( name='MovingAverageRuleParameter', @@ -1795,8 +1912,8 @@ _MOVINGAVERAGERULEPARAMETER = _descriptor.Descriptor( syntax='proto2', extension_ranges=[], oneofs=[], - serialized_start=2947, - serialized_end=2993, ) + serialized_start=3150, + serialized_end=3196, ) _PSRESPONSEMESSAGE = _descriptor.Descriptor( name='PsResponseMessage', @@ -1862,8 +1979,8 @@ _PSRESPONSEMESSAGE = _descriptor.Descriptor( syntax='proto2', extension_ranges=[], oneofs=[], - serialized_start=2995, - serialized_end=3068, ) + serialized_start=3198, + serialized_end=3271, ) _FSCLIENTPARAMETER = _descriptor.Descriptor( name='FsClientParameter', @@ -1993,8 +2110,8 @@ _FSCLIENTPARAMETER = _descriptor.Descriptor( syntax='proto2', extension_ranges=[], oneofs=[], - serialized_start=3071, - serialized_end=3284, ) + serialized_start=3274, + serialized_end=3487, ) _PSPARAMETER.fields_by_name['worker_param'].message_type = _WORKERPARAMETER _PSPARAMETER.fields_by_name['server_param'].message_type = _SERVERPARAMETER @@ -2011,6 +2128,8 @@ _DOWNPOURTRAINERPARAMETER.fields_by_name[ 'dense_table'].message_type = _DENSETABLEPARAMETER _DOWNPOURTRAINERPARAMETER.fields_by_name[ 'sparse_table'].message_type = _SPARSETABLEPARAMETER +_DOWNPOURTRAINERPARAMETER.fields_by_name[ + 'program_config'].message_type = _PROGRAMCONFIG _DOWNPOURSERVERPARAMETER.fields_by_name[ 'downpour_table_param'].message_type = _TABLEPARAMETER _DOWNPOURSERVERPARAMETER.fields_by_name[ @@ -2042,6 +2161,7 @@ DESCRIPTOR.message_types_by_name[ 'DownpourWorkerParameter'] = _DOWNPOURWORKERPARAMETER DESCRIPTOR.message_types_by_name[ 'DownpourTrainerParameter'] = _DOWNPOURTRAINERPARAMETER +DESCRIPTOR.message_types_by_name['ProgramConfig'] = _PROGRAMCONFIG DESCRIPTOR.message_types_by_name['DenseTableParameter'] = _DENSETABLEPARAMETER DESCRIPTOR.message_types_by_name['SparseTableParameter'] = _SPARSETABLEPARAMETER DESCRIPTOR.message_types_by_name[ @@ -2120,6 +2240,16 @@ DownpourTrainerParameter = _reflection.GeneratedProtocolMessageType( )) _sym_db.RegisterMessage(DownpourTrainerParameter) +ProgramConfig = _reflection.GeneratedProtocolMessageType( + 'ProgramConfig', + (_message.Message, ), + dict( + DESCRIPTOR=_PROGRAMCONFIG, + __module__='ps_pb2' + # @@protoc_insertion_point(class_scope:paddle.ProgramConfig) + )) +_sym_db.RegisterMessage(ProgramConfig) + DenseTableParameter = _reflection.GeneratedProtocolMessageType( 'DenseTableParameter', (_message.Message, ), -- GitLab