提交 06213b79 编写于 作者: D dongdaxiang

add hadoop helper function for distributed training

上级 49130f9b
...@@ -150,8 +150,13 @@ class AsyncExecutor(object): ...@@ -150,8 +150,13 @@ class AsyncExecutor(object):
data_feed.desc(), filelist, thread_num, data_feed.desc(), filelist, thread_num,
fetch_var_names, debug) fetch_var_names, debug)
def config_ps(self, dist_desc, host_sign_list, node_num, index): def config_distributed_nodes(self, dist_opt):
self.executor.config_pslib(dist_desc, host_sign_list, node_num, index) # get total rank
# get rank index
# get iplists
# get hadoop info
return
def start_server(self): def start_server(self):
self.executor.start_server() self.executor.start_server()
......
...@@ -21,7 +21,8 @@ def find_distributed_lookup_table_inputs(program, table_name): ...@@ -21,7 +21,8 @@ def find_distributed_lookup_table_inputs(program, table_name):
for op in program.global_block().ops: for op in program.global_block().ops:
if op.type == LOOKUP_TABLE_TYPE: if op.type == LOOKUP_TABLE_TYPE:
if table_name == op.input("W")[0]: if table_name == op.input("W")[0]:
inputs.extend([local_vars[name] for name in op.input("Ids")]) inputs.extend(
[local_vars[name] for name in op.input("Ids")])
return inputs return inputs
def find_distributed_lookup_table_outputs(program, table_name): def find_distributed_lookup_table_outputs(program, table_name):
...@@ -30,7 +31,8 @@ def find_distributed_lookup_table_outputs(program, table_name): ...@@ -30,7 +31,8 @@ def find_distributed_lookup_table_outputs(program, table_name):
for op in program.global_block().ops: for op in program.global_block().ops:
if op.type == LOOKUP_TABLE_TYPE: if op.type == LOOKUP_TABLE_TYPE:
if table_name == op.input("W")[0]: if table_name == op.input("W")[0]:
outputs.extend([local_vars[name] for name in op.output("Out")]) outputs.extend(
[local_vars[name] for name in op.output("Out")])
return outputs return outputs
def find_distributed_lookup_table(program): def find_distributed_lookup_table(program):
......
...@@ -8,30 +8,57 @@ from paddle.fluid.distribute_lookup_table import find_distributed_lookup_table_o ...@@ -8,30 +8,57 @@ from paddle.fluid.distribute_lookup_table import find_distributed_lookup_table_o
from google.protobuf import text_format from google.protobuf import text_format
class DownpourSGD(object): class DownpourSGD(object):
"""
Distributed optimizer of downpour stochastic gradient descent
Standard implementation of Google's Downpour SGD
in Large Scale Distributed Deep Networks
Args:
learning_rate (float): the learning rate used to update parameters. \
Can be a float value
Examples:
.. code-block:: python
downpour_sgd = fluid.distributed.DownpourSGD(learning_rate=0.2)
downpour_sgd.minimize(cost)
"""
def __init__(self, learning_rate=0.001, window=1): def __init__(self, learning_rate=0.001, window=1):
# todo(guru4elephant): if optimizer is not None, will warning here # todo(guru4elephant): add more optimizers here as argument
# todo(guru4elephant): make learning_rate as a variable
self.learning_rate_ = learning_rate self.learning_rate_ = learning_rate
self.window_ = window self.window_ = window
self.type = "downpour"
def minimize(self, loss, startup_program=None, def minimize(self, loss, startup_program=None,
parameter_list=None, no_grad_set=None): parameter_list=None, no_grad_set=None):
params_grads = sorted(append_backward(loss), key=lambda x:x[0].name) params_grads = sorted(append_backward(
loss, parameter_list, no_grad_set), key=lambda x:x[0].name)
table_name = find_distributed_lookup_table(loss.block.program) table_name = find_distributed_lookup_table(loss.block.program)
prefetch_slots = find_distributed_lookup_table_inputs( prefetch_slots = find_distributed_lookup_table_inputs(
loss.block.program, table_name) loss.block.program, table_name)
prefetch_slots_emb = find_distributed_lookup_table_outputs( prefetch_slots_emb = find_distributed_lookup_table_outputs(
loss.block.program, table_name) loss.block.program, table_name)
server = DownpourServer() server = DownpourServer()
# window is communication strategy
worker = DownpourWorker(self.window_) worker = DownpourWorker(self.window_)
server.add_sparse_table(0, self.learning_rate_, # Todo(guru4elephant): support multiple tables definitions
# currently support one big sparse table
sparse_table_index = 0
# currently merge all dense parameters into one dense table
dense_table_index = 1
server.add_sparse_table(sparse_table_index, self.learning_rate_,
prefetch_slots, prefetch_slots_emb) prefetch_slots, prefetch_slots_emb)
server.add_dense_table(1, self.learning_rate_, params_grads[0], params_grads[1]) server.add_dense_table(dense_table_index, self.learning_rate_,
worker.add_sparse_table(0, self.learning_rate_, params_grads[0], params_grads[1])
worker.add_sparse_table(sparse_table_index, self.learning_rate_,
prefetch_slots, prefetch_slots_emb) prefetch_slots, prefetch_slots_emb)
worker.add_dense_table(1, self.learning_rate_, params_grads[0], params_grads[1]) worker.add_dense_table(dense_table_index, self.learning_rate_,
params_grads[0], params_grads[1])
ps_param = pslib.PSParameter() ps_param = pslib.PSParameter()
ps_param.server_param.CopyFrom(server.get_desc()) ps_param.server_param.CopyFrom(server.get_desc())
#ps_param.worker_param.CopyFrom(worker.get_desc()) ps_param.worker_param.CopyFrom(worker.get_desc())
# Todo(guru4elephant): figure out how to support more sparse parameters
# currently only support lookup_table
worker_skipped_ops = ["lookup_table", "lookup_table_grad"] worker_skipped_ops = ["lookup_table", "lookup_table_grad"]
ps_param_str = text_format.MessageToString(ps_param) ps_param_str = text_format.MessageToString(ps_param)
return [ps_param_str, worker_skipped_ops, text_format.MessageToString(worker.get_desc())] return [ps_param_str, worker_skipped_ops]
from mpi4py import MPI from mpi4py import MPI
class FileSystem(object):
def __init__(self, fs_type="afs",
uri="afs://tianqi.afs.baidu.com:9902",
user=None,
passwd=None,
hadoop_bin="",
afs_conf=None):
assert user not None
assert passwd not None
assert hadoop_bin not None
fs_client = pslib.FsClientParameter()
if fs_type == "afs":
fs_client.fs_type = pslib.FsApiType.AFS
else:
fs_client.fs_type = pslib.FsApiType.HDFS
fs_client.uri = uri
fs_client.user = user
fs_client.passwd = passwd
fs_client.buffer_size = 0
fs_client.afs_conf = afs_conf if not afs_conf else ""
class MPIHelper(object): class MPIHelper(object):
def __init__(self): def __init__(self):
self.comm = MPI.COMM_WORLD self.comm = MPI.COMM_WORLD
...@@ -18,3 +40,5 @@ class MPIHelper(object): ...@@ -18,3 +40,5 @@ class MPIHelper(object):
def get_hostname(self): def get_hostname(self):
import socket import socket
return socket.gethostname() return socket.gethostname()
...@@ -12,7 +12,6 @@ class Worker(object): ...@@ -12,7 +12,6 @@ class Worker(object):
class DownpourServer(Server): class DownpourServer(Server):
def __init__(self): def __init__(self):
#self.server_ = pslib.ServerParameter().downpour_server_param
self.server_ = pslib.ServerParameter() self.server_ = pslib.ServerParameter()
def add_sparse_table(self, table_id, learning_rate, def add_sparse_table(self, table_id, learning_rate,
......
...@@ -20,7 +20,7 @@ DESCRIPTOR = _descriptor.FileDescriptor( ...@@ -20,7 +20,7 @@ DESCRIPTOR = _descriptor.FileDescriptor(
name='ps.proto', name='ps.proto',
package='paddle', package='paddle',
syntax='proto2', syntax='proto2',
serialized_pb=_b('\n\x08ps.proto\x12\x06paddle\"\xe4\x01\n\x0bPSParameter\x12\x14\n\x0cworker_class\x18\x01 \x01(\t\x12\x14\n\x0cserver_class\x18\x02 \x01(\t\x12\x16\n\x0einstance_class\x18\x03 \x01(\t\x12-\n\x0cworker_param\x18\x65 \x01(\x0b\x32\x17.paddle.WorkerParameter\x12-\n\x0cserver_param\x18\x66 \x01(\x0b\x32\x17.paddle.ServerParameter\x12\x33\n\x0f\x66s_client_param\x18\xf5\x03 \x01(\x0b\x32\x19.paddle.FsClientParameter\"Q\n\x0fWorkerParameter\x12>\n\x15\x64ownpour_worker_param\x18\x01 \x01(\x0b\x32\x1f.paddle.DownpourWorkerParameter\"Q\n\x0fServerParameter\x12>\n\x15\x64ownpour_server_param\x18\x01 \x01(\x0b\x32\x1f.paddle.DownpourServerParameter\"O\n\x17\x44ownpourWorkerParameter\x12\x34\n\x14\x64ownpour_table_param\x18\x01 \x03(\x0b\x32\x16.paddle.TableParameter\"\xbc\x01\n\x18\x44ownpourTrainerParameter\x12\x30\n\x0b\x64\x65nse_table\x18\x02 \x03(\x0b\x32\x1b.paddle.DenseTableParameter\x12\x32\n\x0csparse_table\x18\x03 \x03(\x0b\x32\x1c.paddle.SparseTableParameter\x12\x1c\n\x14pull_dense_per_batch\x18\x04 \x01(\x05\x12\x1c\n\x14push_dense_per_batch\x18\x05 \x01(\x05\"{\n\x13\x44\x65nseTableParameter\x12\x10\n\x08table_id\x18\x01 \x01(\x05\x12\x1b\n\x13\x64\x65nse_variable_name\x18\x02 \x03(\t\x12$\n\x1c\x64\x65nse_gradient_variable_name\x18\x03 \x03(\t\x12\x0f\n\x07\x66\x65\x61_dim\x18\x04 \x01(\x05\"z\n\x14SparseTableParameter\x12\x10\n\x08table_id\x18\x01 \x01(\x05\x12\x13\n\x0b\x66\x65\x61ture_dim\x18\x02 \x01(\x05\x12\x10\n\x08slot_key\x18\x03 \x03(\t\x12\x12\n\nslot_value\x18\x04 \x03(\t\x12\x15\n\rslot_gradient\x18\x05 \x03(\t\"\x86\x01\n\x17\x44ownpourServerParameter\x12\x34\n\x14\x64ownpour_table_param\x18\x01 \x03(\x0b\x32\x16.paddle.TableParameter\x12\x35\n\rservice_param\x18\x02 \x01(\x0b\x32\x1e.paddle.ServerServiceParameter\"\x91\x01\n\x16ServerServiceParameter\x12\x14\n\x0cserver_class\x18\x01 \x01(\t\x12\x14\n\x0c\x63lient_class\x18\x02 \x01(\t\x12\x15\n\rservice_class\x18\x03 \x01(\t\x12\x19\n\x11start_server_port\x18\x04 \x01(\r\x12\x19\n\x11server_thread_num\x18\x05 \x01(\r\"\xbf\x01\n\x0eTableParameter\x12\x10\n\x08table_id\x18\x01 \x01(\x04\x12\x13\n\x0btable_class\x18\x02 \x01(\t\x12\x12\n\nshared_num\x18\x03 \x01(\x04\x12\x30\n\x08\x61\x63\x63\x65ssor\x18\x04 \x01(\x0b\x32\x1e.paddle.TableAccessorParameter\x12\x1f\n\x04type\x18\x05 \x01(\x0e\x32\x11.paddle.TableType\x12\x1f\n\x10\x63ompress_in_save\x18\x06 \x01(\x08:\x05\x66\x61lse\"\xf1\x02\n\x16TableAccessorParameter\x12\x16\n\x0e\x61\x63\x63\x65ssor_class\x18\x01 \x01(\t\x12\x38\n\x10sparse_sgd_param\x18\x02 \x01(\x0b\x32\x1e.paddle.SparseSGDRuleParameter\x12\x36\n\x0f\x64\x65nse_sgd_param\x18\x03 \x01(\x0b\x32\x1d.paddle.DenseSGDRuleParameter\x12\x0f\n\x07\x66\x65\x61_dim\x18\x04 \x01(\r\x12\x12\n\nembedx_dim\x18\x05 \x01(\r\x12\x18\n\x10\x65mbedx_threshold\x18\x06 \x01(\r\x12G\n\x17\x64ownpour_accessor_param\x18\x07 \x01(\x0b\x32&.paddle.DownpourTableAccessorParameter\x12\x45\n\x19table_accessor_save_param\x18\x08 \x03(\x0b\x32\".paddle.TableAccessorSaveParameter\"\xce\x01\n\x1e\x44ownpourTableAccessorParameter\x12\x14\n\x0cnonclk_coeff\x18\x01 \x01(\x02\x12\x13\n\x0b\x63lick_coeff\x18\x02 \x01(\x02\x12\x16\n\x0e\x62\x61se_threshold\x18\x03 \x01(\x02\x12\x17\n\x0f\x64\x65lta_threshold\x18\x04 \x01(\x02\x12\x17\n\x0f\x64\x65lta_keep_days\x18\x05 \x01(\x02\x12\x1d\n\x15show_click_decay_rate\x18\x06 \x01(\x02\x12\x18\n\x10\x64\x65lete_threshold\x18\x07 \x01(\x02\"S\n\x1aTableAccessorSaveParameter\x12\r\n\x05param\x18\x01 \x01(\r\x12\x11\n\tconverter\x18\x02 \x01(\t\x12\x13\n\x0b\x64\x65\x63onverter\x18\x03 \x01(\t\"e\n\x10PsRequestMessage\x12\x0e\n\x06\x63md_id\x18\x01 \x02(\r\x12\x10\n\x08table_id\x18\x02 \x01(\r\x12\x0e\n\x06params\x18\x03 \x03(\x0c\x12\x11\n\tclient_id\x18\x04 \x01(\x05\x12\x0c\n\x04\x64\x61ta\x18\x05 \x01(\x0c\"w\n\x16SparseSGDRuleParameter\x12\x15\n\rlearning_rate\x18\x01 \x01(\x01\x12\x15\n\rinitial_g2sum\x18\x02 \x01(\x01\x12\x18\n\rinitial_range\x18\x03 \x01(\x01:\x01\x30\x12\x15\n\rweight_bounds\x18\x04 \x03(\x02\"\xe1\x01\n\x15\x44\x65nseSGDRuleParameter\x12\x0c\n\x04name\x18\x01 \x01(\t\x12&\n\x04\x61\x64\x61m\x18\x02 \x01(\x0b\x32\x18.paddle.AdamSGDParameter\x12(\n\x05naive\x18\x03 \x01(\x0b\x32\x19.paddle.NaiveSGDParameter\x12,\n\x07summary\x18\x04 \x01(\x0b\x32\x1b.paddle.SummarySGDParameter\x12:\n\x0emoving_average\x18\x05 \x01(\x0b\x32\".paddle.MovingAverageRuleParameter\"\x86\x01\n\x10\x41\x64\x61mSGDParameter\x12\x15\n\rlearning_rate\x18\x01 \x01(\x01\x12\x16\n\x0e\x61vg_decay_rate\x18\x02 \x01(\x01\x12\x16\n\x0e\x61\x64\x61_decay_rate\x18\x03 \x01(\x01\x12\x13\n\x0b\x61\x64\x61_epsilon\x18\x04 \x01(\x01\x12\x16\n\x0emom_decay_rate\x18\x05 \x01(\x01\"B\n\x11NaiveSGDParameter\x12\x15\n\rlearning_rate\x18\x01 \x01(\x01\x12\x16\n\x0e\x61vg_decay_rate\x18\x02 \x01(\x01\";\n\x13SummarySGDParameter\x12$\n\x12summary_decay_rate\x18\x01 \x01(\x01:\x08\x30.999999\".\n\x1aMovingAverageRuleParameter\x12\x10\n\x08momentum\x18\x01 \x01(\x01\"I\n\x11PsResponseMessage\x12\x13\n\x08\x65rr_code\x18\x01 \x02(\x05:\x01\x30\x12\x11\n\x07\x65rr_msg\x18\x02 \x02(\t:\x00\x12\x0c\n\x04\x64\x61ta\x18\x03 \x01(\x0c\"\xd5\x01\n\x11\x46sClientParameter\x12:\n\x07\x66s_type\x18\x01 \x01(\x0e\x32#.paddle.FsClientParameter.FsApiType:\x04HDFS\x12\x0b\n\x03uri\x18\x02 \x01(\t\x12\x0c\n\x04user\x18\x03 \x01(\t\x12\x0e\n\x06passwd\x18\x04 \x01(\t\x12\x13\n\x0b\x62uffer_size\x18\x05 \x01(\x05\x12\x12\n\nhadoop_bin\x18\x33 \x01(\t\x12\x10\n\x08\x61\x66s_conf\x18\x65 \x01(\t\"\x1e\n\tFsApiType\x12\x08\n\x04HDFS\x10\x00\x12\x07\n\x03\x41\x46S\x10\x01*4\n\tTableType\x12\x13\n\x0fPS_SPARSE_TABLE\x10\x00\x12\x12\n\x0ePS_DENSE_TABLE\x10\x01*\xbd\x02\n\x07PsCmdID\x12\x17\n\x13PS_PULL_DENSE_TABLE\x10\x00\x12\x17\n\x13PS_PUSH_DENSE_TABLE\x10\x01\x12\x18\n\x14PS_PULL_SPARSE_TABLE\x10\x02\x12\x18\n\x14PS_PUSH_SPARSE_TABLE\x10\x03\x12\x13\n\x0fPS_SHRINK_TABLE\x10\x04\x12\x15\n\x11PS_SAVE_ONE_TABLE\x10\x05\x12\x15\n\x11PS_SAVE_ALL_TABLE\x10\x06\x12\x15\n\x11PS_LOAD_ONE_TABLE\x10\x07\x12\x15\n\x11PS_LOAD_ALL_TABLE\x10\x08\x12\x16\n\x12PS_CLEAR_ONE_TABLE\x10\t\x12\x16\n\x12PS_CLEAR_ALL_TABLE\x10\n\x12\x17\n\x13PS_PUSH_DENSE_PARAM\x10\x0b\x12\x12\n\x0ePS_STOP_SERVER\x10\x0c\x32K\n\tPsService\x12>\n\x07service\x12\x18.paddle.PsRequestMessage\x1a\x19.paddle.PsResponseMessageB\x03\x80\x01\x01') serialized_pb=_b('\n\x08ps.proto\x12\x06paddle\"\x9e\x02\n\x0bPSParameter\x12\x14\n\x0cworker_class\x18\x01 \x01(\t\x12\x14\n\x0cserver_class\x18\x02 \x01(\t\x12\x16\n\x0einstance_class\x18\x03 \x01(\t\x12-\n\x0cworker_param\x18\x65 \x01(\x0b\x32\x17.paddle.WorkerParameter\x12-\n\x0cserver_param\x18\x66 \x01(\x0b\x32\x17.paddle.ServerParameter\x12\x38\n\rtrainer_param\x18\xad\x02 \x01(\x0b\x32 .paddle.DownpourTrainerParameter\x12\x33\n\x0f\x66s_client_param\x18\xf5\x03 \x01(\x0b\x32\x19.paddle.FsClientParameter\"Q\n\x0fWorkerParameter\x12>\n\x15\x64ownpour_worker_param\x18\x01 \x01(\x0b\x32\x1f.paddle.DownpourWorkerParameter\"Q\n\x0fServerParameter\x12>\n\x15\x64ownpour_server_param\x18\x01 \x01(\x0b\x32\x1f.paddle.DownpourServerParameter\"O\n\x17\x44ownpourWorkerParameter\x12\x34\n\x14\x64ownpour_table_param\x18\x01 \x03(\x0b\x32\x16.paddle.TableParameter\"\xbc\x01\n\x18\x44ownpourTrainerParameter\x12\x30\n\x0b\x64\x65nse_table\x18\x01 \x03(\x0b\x32\x1b.paddle.DenseTableParameter\x12\x32\n\x0csparse_table\x18\x02 \x03(\x0b\x32\x1c.paddle.SparseTableParameter\x12\x1c\n\x14pull_dense_per_batch\x18\x03 \x01(\x05\x12\x1c\n\x14push_dense_per_batch\x18\x04 \x01(\x05\"{\n\x13\x44\x65nseTableParameter\x12\x10\n\x08table_id\x18\x01 \x01(\x05\x12\x1b\n\x13\x64\x65nse_variable_name\x18\x02 \x03(\t\x12$\n\x1c\x64\x65nse_gradient_variable_name\x18\x03 \x03(\t\x12\x0f\n\x07\x66\x65\x61_dim\x18\x04 \x01(\x05\"z\n\x14SparseTableParameter\x12\x10\n\x08table_id\x18\x01 \x01(\x05\x12\x13\n\x0b\x66\x65\x61ture_dim\x18\x02 \x01(\x05\x12\x10\n\x08slot_key\x18\x03 \x03(\t\x12\x12\n\nslot_value\x18\x04 \x03(\t\x12\x15\n\rslot_gradient\x18\x05 \x03(\t\"\x86\x01\n\x17\x44ownpourServerParameter\x12\x34\n\x14\x64ownpour_table_param\x18\x01 \x03(\x0b\x32\x16.paddle.TableParameter\x12\x35\n\rservice_param\x18\x02 \x01(\x0b\x32\x1e.paddle.ServerServiceParameter\"\x91\x01\n\x16ServerServiceParameter\x12\x14\n\x0cserver_class\x18\x01 \x01(\t\x12\x14\n\x0c\x63lient_class\x18\x02 \x01(\t\x12\x15\n\rservice_class\x18\x03 \x01(\t\x12\x19\n\x11start_server_port\x18\x04 \x01(\r\x12\x19\n\x11server_thread_num\x18\x05 \x01(\r\"\xbf\x01\n\x0eTableParameter\x12\x10\n\x08table_id\x18\x01 \x01(\x04\x12\x13\n\x0btable_class\x18\x02 \x01(\t\x12\x12\n\nshared_num\x18\x03 \x01(\x04\x12\x30\n\x08\x61\x63\x63\x65ssor\x18\x04 \x01(\x0b\x32\x1e.paddle.TableAccessorParameter\x12\x1f\n\x04type\x18\x05 \x01(\x0e\x32\x11.paddle.TableType\x12\x1f\n\x10\x63ompress_in_save\x18\x06 \x01(\x08:\x05\x66\x61lse\"\xf1\x02\n\x16TableAccessorParameter\x12\x16\n\x0e\x61\x63\x63\x65ssor_class\x18\x01 \x01(\t\x12\x38\n\x10sparse_sgd_param\x18\x02 \x01(\x0b\x32\x1e.paddle.SparseSGDRuleParameter\x12\x36\n\x0f\x64\x65nse_sgd_param\x18\x03 \x01(\x0b\x32\x1d.paddle.DenseSGDRuleParameter\x12\x0f\n\x07\x66\x65\x61_dim\x18\x04 \x01(\r\x12\x12\n\nembedx_dim\x18\x05 \x01(\r\x12\x18\n\x10\x65mbedx_threshold\x18\x06 \x01(\r\x12G\n\x17\x64ownpour_accessor_param\x18\x07 \x01(\x0b\x32&.paddle.DownpourTableAccessorParameter\x12\x45\n\x19table_accessor_save_param\x18\x08 \x03(\x0b\x32\".paddle.TableAccessorSaveParameter\"\xce\x01\n\x1e\x44ownpourTableAccessorParameter\x12\x14\n\x0cnonclk_coeff\x18\x01 \x01(\x02\x12\x13\n\x0b\x63lick_coeff\x18\x02 \x01(\x02\x12\x16\n\x0e\x62\x61se_threshold\x18\x03 \x01(\x02\x12\x17\n\x0f\x64\x65lta_threshold\x18\x04 \x01(\x02\x12\x17\n\x0f\x64\x65lta_keep_days\x18\x05 \x01(\x02\x12\x1d\n\x15show_click_decay_rate\x18\x06 \x01(\x02\x12\x18\n\x10\x64\x65lete_threshold\x18\x07 \x01(\x02\"S\n\x1aTableAccessorSaveParameter\x12\r\n\x05param\x18\x01 \x01(\r\x12\x11\n\tconverter\x18\x02 \x01(\t\x12\x13\n\x0b\x64\x65\x63onverter\x18\x03 \x01(\t\"e\n\x10PsRequestMessage\x12\x0e\n\x06\x63md_id\x18\x01 \x02(\r\x12\x10\n\x08table_id\x18\x02 \x01(\r\x12\x0e\n\x06params\x18\x03 \x03(\x0c\x12\x11\n\tclient_id\x18\x04 \x01(\x05\x12\x0c\n\x04\x64\x61ta\x18\x05 \x01(\x0c\"w\n\x16SparseSGDRuleParameter\x12\x15\n\rlearning_rate\x18\x01 \x01(\x01\x12\x15\n\rinitial_g2sum\x18\x02 \x01(\x01\x12\x18\n\rinitial_range\x18\x03 \x01(\x01:\x01\x30\x12\x15\n\rweight_bounds\x18\x04 \x03(\x02\"\xe1\x01\n\x15\x44\x65nseSGDRuleParameter\x12\x0c\n\x04name\x18\x01 \x01(\t\x12&\n\x04\x61\x64\x61m\x18\x02 \x01(\x0b\x32\x18.paddle.AdamSGDParameter\x12(\n\x05naive\x18\x03 \x01(\x0b\x32\x19.paddle.NaiveSGDParameter\x12,\n\x07summary\x18\x04 \x01(\x0b\x32\x1b.paddle.SummarySGDParameter\x12:\n\x0emoving_average\x18\x05 \x01(\x0b\x32\".paddle.MovingAverageRuleParameter\"\x86\x01\n\x10\x41\x64\x61mSGDParameter\x12\x15\n\rlearning_rate\x18\x01 \x01(\x01\x12\x16\n\x0e\x61vg_decay_rate\x18\x02 \x01(\x01\x12\x16\n\x0e\x61\x64\x61_decay_rate\x18\x03 \x01(\x01\x12\x13\n\x0b\x61\x64\x61_epsilon\x18\x04 \x01(\x01\x12\x16\n\x0emom_decay_rate\x18\x05 \x01(\x01\"B\n\x11NaiveSGDParameter\x12\x15\n\rlearning_rate\x18\x01 \x01(\x01\x12\x16\n\x0e\x61vg_decay_rate\x18\x02 \x01(\x01\";\n\x13SummarySGDParameter\x12$\n\x12summary_decay_rate\x18\x01 \x01(\x01:\x08\x30.999999\".\n\x1aMovingAverageRuleParameter\x12\x10\n\x08momentum\x18\x01 \x01(\x01\"I\n\x11PsResponseMessage\x12\x13\n\x08\x65rr_code\x18\x01 \x02(\x05:\x01\x30\x12\x11\n\x07\x65rr_msg\x18\x02 \x02(\t:\x00\x12\x0c\n\x04\x64\x61ta\x18\x03 \x01(\x0c\"\xd5\x01\n\x11\x46sClientParameter\x12:\n\x07\x66s_type\x18\x01 \x01(\x0e\x32#.paddle.FsClientParameter.FsApiType:\x04HDFS\x12\x0b\n\x03uri\x18\x02 \x01(\t\x12\x0c\n\x04user\x18\x03 \x01(\t\x12\x0e\n\x06passwd\x18\x04 \x01(\t\x12\x13\n\x0b\x62uffer_size\x18\x05 \x01(\x05\x12\x12\n\nhadoop_bin\x18\x33 \x01(\t\x12\x10\n\x08\x61\x66s_conf\x18\x65 \x01(\t\"\x1e\n\tFsApiType\x12\x08\n\x04HDFS\x10\x00\x12\x07\n\x03\x41\x46S\x10\x01*4\n\tTableType\x12\x13\n\x0fPS_SPARSE_TABLE\x10\x00\x12\x12\n\x0ePS_DENSE_TABLE\x10\x01*\xbd\x02\n\x07PsCmdID\x12\x17\n\x13PS_PULL_DENSE_TABLE\x10\x00\x12\x17\n\x13PS_PUSH_DENSE_TABLE\x10\x01\x12\x18\n\x14PS_PULL_SPARSE_TABLE\x10\x02\x12\x18\n\x14PS_PUSH_SPARSE_TABLE\x10\x03\x12\x13\n\x0fPS_SHRINK_TABLE\x10\x04\x12\x15\n\x11PS_SAVE_ONE_TABLE\x10\x05\x12\x15\n\x11PS_SAVE_ALL_TABLE\x10\x06\x12\x15\n\x11PS_LOAD_ONE_TABLE\x10\x07\x12\x15\n\x11PS_LOAD_ALL_TABLE\x10\x08\x12\x16\n\x12PS_CLEAR_ONE_TABLE\x10\t\x12\x16\n\x12PS_CLEAR_ALL_TABLE\x10\n\x12\x17\n\x13PS_PUSH_DENSE_PARAM\x10\x0b\x12\x12\n\x0ePS_STOP_SERVER\x10\x0c\x32K\n\tPsService\x12>\n\x07service\x12\x18.paddle.PsRequestMessage\x1a\x19.paddle.PsResponseMessageB\x03\x80\x01\x01')
) )
_sym_db.RegisterFileDescriptor(DESCRIPTOR) _sym_db.RegisterFileDescriptor(DESCRIPTOR)
...@@ -41,8 +41,8 @@ _TABLETYPE = _descriptor.EnumDescriptor( ...@@ -41,8 +41,8 @@ _TABLETYPE = _descriptor.EnumDescriptor(
], ],
containing_type=None, containing_type=None,
options=None, options=None,
serialized_start=3140, serialized_start=3198,
serialized_end=3192, serialized_end=3250,
) )
_sym_db.RegisterEnumDescriptor(_TABLETYPE) _sym_db.RegisterEnumDescriptor(_TABLETYPE)
...@@ -108,8 +108,8 @@ _PSCMDID = _descriptor.EnumDescriptor( ...@@ -108,8 +108,8 @@ _PSCMDID = _descriptor.EnumDescriptor(
], ],
containing_type=None, containing_type=None,
options=None, options=None,
serialized_start=3195, serialized_start=3253,
serialized_end=3512, serialized_end=3570,
) )
_sym_db.RegisterEnumDescriptor(_PSCMDID) _sym_db.RegisterEnumDescriptor(_PSCMDID)
...@@ -148,8 +148,8 @@ _FSCLIENTPARAMETER_FSAPITYPE = _descriptor.EnumDescriptor( ...@@ -148,8 +148,8 @@ _FSCLIENTPARAMETER_FSAPITYPE = _descriptor.EnumDescriptor(
], ],
containing_type=None, containing_type=None,
options=None, options=None,
serialized_start=3108, serialized_start=3166,
serialized_end=3138, serialized_end=3196,
) )
_sym_db.RegisterEnumDescriptor(_FSCLIENTPARAMETER_FSAPITYPE) _sym_db.RegisterEnumDescriptor(_FSCLIENTPARAMETER_FSAPITYPE)
...@@ -197,7 +197,14 @@ _PSPARAMETER = _descriptor.Descriptor( ...@@ -197,7 +197,14 @@ _PSPARAMETER = _descriptor.Descriptor(
is_extension=False, extension_scope=None, is_extension=False, extension_scope=None,
options=None), options=None),
_descriptor.FieldDescriptor( _descriptor.FieldDescriptor(
name='fs_client_param', full_name='paddle.PSParameter.fs_client_param', index=5, name='trainer_param', full_name='paddle.PSParameter.trainer_param', index=5,
number=301, type=11, cpp_type=10, label=1,
has_default_value=False, default_value=None,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='fs_client_param', full_name='paddle.PSParameter.fs_client_param', index=6,
number=501, type=11, cpp_type=10, label=1, number=501, type=11, cpp_type=10, label=1,
has_default_value=False, default_value=None, has_default_value=False, default_value=None,
message_type=None, enum_type=None, containing_type=None, message_type=None, enum_type=None, containing_type=None,
...@@ -216,7 +223,7 @@ _PSPARAMETER = _descriptor.Descriptor( ...@@ -216,7 +223,7 @@ _PSPARAMETER = _descriptor.Descriptor(
oneofs=[ oneofs=[
], ],
serialized_start=21, serialized_start=21,
serialized_end=249, serialized_end=307,
) )
...@@ -246,8 +253,8 @@ _WORKERPARAMETER = _descriptor.Descriptor( ...@@ -246,8 +253,8 @@ _WORKERPARAMETER = _descriptor.Descriptor(
extension_ranges=[], extension_ranges=[],
oneofs=[ oneofs=[
], ],
serialized_start=251, serialized_start=309,
serialized_end=332, serialized_end=390,
) )
...@@ -277,8 +284,8 @@ _SERVERPARAMETER = _descriptor.Descriptor( ...@@ -277,8 +284,8 @@ _SERVERPARAMETER = _descriptor.Descriptor(
extension_ranges=[], extension_ranges=[],
oneofs=[ oneofs=[
], ],
serialized_start=334, serialized_start=392,
serialized_end=415, serialized_end=473,
) )
...@@ -308,8 +315,8 @@ _DOWNPOURWORKERPARAMETER = _descriptor.Descriptor( ...@@ -308,8 +315,8 @@ _DOWNPOURWORKERPARAMETER = _descriptor.Descriptor(
extension_ranges=[], extension_ranges=[],
oneofs=[ oneofs=[
], ],
serialized_start=417, serialized_start=475,
serialized_end=496, serialized_end=554,
) )
...@@ -322,28 +329,28 @@ _DOWNPOURTRAINERPARAMETER = _descriptor.Descriptor( ...@@ -322,28 +329,28 @@ _DOWNPOURTRAINERPARAMETER = _descriptor.Descriptor(
fields=[ fields=[
_descriptor.FieldDescriptor( _descriptor.FieldDescriptor(
name='dense_table', full_name='paddle.DownpourTrainerParameter.dense_table', index=0, name='dense_table', full_name='paddle.DownpourTrainerParameter.dense_table', index=0,
number=2, type=11, cpp_type=10, label=3, number=1, type=11, cpp_type=10, label=3,
has_default_value=False, default_value=[], has_default_value=False, default_value=[],
message_type=None, enum_type=None, containing_type=None, message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None, is_extension=False, extension_scope=None,
options=None), options=None),
_descriptor.FieldDescriptor( _descriptor.FieldDescriptor(
name='sparse_table', full_name='paddle.DownpourTrainerParameter.sparse_table', index=1, name='sparse_table', full_name='paddle.DownpourTrainerParameter.sparse_table', index=1,
number=3, type=11, cpp_type=10, label=3, number=2, type=11, cpp_type=10, label=3,
has_default_value=False, default_value=[], has_default_value=False, default_value=[],
message_type=None, enum_type=None, containing_type=None, message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None, is_extension=False, extension_scope=None,
options=None), options=None),
_descriptor.FieldDescriptor( _descriptor.FieldDescriptor(
name='pull_dense_per_batch', full_name='paddle.DownpourTrainerParameter.pull_dense_per_batch', index=2, name='pull_dense_per_batch', full_name='paddle.DownpourTrainerParameter.pull_dense_per_batch', index=2,
number=4, type=5, cpp_type=1, label=1, number=3, type=5, cpp_type=1, label=1,
has_default_value=False, default_value=0, has_default_value=False, default_value=0,
message_type=None, enum_type=None, containing_type=None, message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None, is_extension=False, extension_scope=None,
options=None), options=None),
_descriptor.FieldDescriptor( _descriptor.FieldDescriptor(
name='push_dense_per_batch', full_name='paddle.DownpourTrainerParameter.push_dense_per_batch', index=3, name='push_dense_per_batch', full_name='paddle.DownpourTrainerParameter.push_dense_per_batch', index=3,
number=5, type=5, cpp_type=1, label=1, number=4, type=5, cpp_type=1, label=1,
has_default_value=False, default_value=0, has_default_value=False, default_value=0,
message_type=None, enum_type=None, containing_type=None, message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None, is_extension=False, extension_scope=None,
...@@ -360,8 +367,8 @@ _DOWNPOURTRAINERPARAMETER = _descriptor.Descriptor( ...@@ -360,8 +367,8 @@ _DOWNPOURTRAINERPARAMETER = _descriptor.Descriptor(
extension_ranges=[], extension_ranges=[],
oneofs=[ oneofs=[
], ],
serialized_start=499, serialized_start=557,
serialized_end=687, serialized_end=745,
) )
...@@ -412,8 +419,8 @@ _DENSETABLEPARAMETER = _descriptor.Descriptor( ...@@ -412,8 +419,8 @@ _DENSETABLEPARAMETER = _descriptor.Descriptor(
extension_ranges=[], extension_ranges=[],
oneofs=[ oneofs=[
], ],
serialized_start=689, serialized_start=747,
serialized_end=812, serialized_end=870,
) )
...@@ -471,8 +478,8 @@ _SPARSETABLEPARAMETER = _descriptor.Descriptor( ...@@ -471,8 +478,8 @@ _SPARSETABLEPARAMETER = _descriptor.Descriptor(
extension_ranges=[], extension_ranges=[],
oneofs=[ oneofs=[
], ],
serialized_start=814, serialized_start=872,
serialized_end=936, serialized_end=994,
) )
...@@ -509,8 +516,8 @@ _DOWNPOURSERVERPARAMETER = _descriptor.Descriptor( ...@@ -509,8 +516,8 @@ _DOWNPOURSERVERPARAMETER = _descriptor.Descriptor(
extension_ranges=[], extension_ranges=[],
oneofs=[ oneofs=[
], ],
serialized_start=939, serialized_start=997,
serialized_end=1073, serialized_end=1131,
) )
...@@ -568,8 +575,8 @@ _SERVERSERVICEPARAMETER = _descriptor.Descriptor( ...@@ -568,8 +575,8 @@ _SERVERSERVICEPARAMETER = _descriptor.Descriptor(
extension_ranges=[], extension_ranges=[],
oneofs=[ oneofs=[
], ],
serialized_start=1076, serialized_start=1134,
serialized_end=1221, serialized_end=1279,
) )
...@@ -634,8 +641,8 @@ _TABLEPARAMETER = _descriptor.Descriptor( ...@@ -634,8 +641,8 @@ _TABLEPARAMETER = _descriptor.Descriptor(
extension_ranges=[], extension_ranges=[],
oneofs=[ oneofs=[
], ],
serialized_start=1224, serialized_start=1282,
serialized_end=1415, serialized_end=1473,
) )
...@@ -714,8 +721,8 @@ _TABLEACCESSORPARAMETER = _descriptor.Descriptor( ...@@ -714,8 +721,8 @@ _TABLEACCESSORPARAMETER = _descriptor.Descriptor(
extension_ranges=[], extension_ranges=[],
oneofs=[ oneofs=[
], ],
serialized_start=1418, serialized_start=1476,
serialized_end=1787, serialized_end=1845,
) )
...@@ -787,8 +794,8 @@ _DOWNPOURTABLEACCESSORPARAMETER = _descriptor.Descriptor( ...@@ -787,8 +794,8 @@ _DOWNPOURTABLEACCESSORPARAMETER = _descriptor.Descriptor(
extension_ranges=[], extension_ranges=[],
oneofs=[ oneofs=[
], ],
serialized_start=1790, serialized_start=1848,
serialized_end=1996, serialized_end=2054,
) )
...@@ -832,8 +839,8 @@ _TABLEACCESSORSAVEPARAMETER = _descriptor.Descriptor( ...@@ -832,8 +839,8 @@ _TABLEACCESSORSAVEPARAMETER = _descriptor.Descriptor(
extension_ranges=[], extension_ranges=[],
oneofs=[ oneofs=[
], ],
serialized_start=1998, serialized_start=2056,
serialized_end=2081, serialized_end=2139,
) )
...@@ -891,8 +898,8 @@ _PSREQUESTMESSAGE = _descriptor.Descriptor( ...@@ -891,8 +898,8 @@ _PSREQUESTMESSAGE = _descriptor.Descriptor(
extension_ranges=[], extension_ranges=[],
oneofs=[ oneofs=[
], ],
serialized_start=2083, serialized_start=2141,
serialized_end=2184, serialized_end=2242,
) )
...@@ -943,8 +950,8 @@ _SPARSESGDRULEPARAMETER = _descriptor.Descriptor( ...@@ -943,8 +950,8 @@ _SPARSESGDRULEPARAMETER = _descriptor.Descriptor(
extension_ranges=[], extension_ranges=[],
oneofs=[ oneofs=[
], ],
serialized_start=2186, serialized_start=2244,
serialized_end=2305, serialized_end=2363,
) )
...@@ -1002,8 +1009,8 @@ _DENSESGDRULEPARAMETER = _descriptor.Descriptor( ...@@ -1002,8 +1009,8 @@ _DENSESGDRULEPARAMETER = _descriptor.Descriptor(
extension_ranges=[], extension_ranges=[],
oneofs=[ oneofs=[
], ],
serialized_start=2308, serialized_start=2366,
serialized_end=2533, serialized_end=2591,
) )
...@@ -1061,8 +1068,8 @@ _ADAMSGDPARAMETER = _descriptor.Descriptor( ...@@ -1061,8 +1068,8 @@ _ADAMSGDPARAMETER = _descriptor.Descriptor(
extension_ranges=[], extension_ranges=[],
oneofs=[ oneofs=[
], ],
serialized_start=2536, serialized_start=2594,
serialized_end=2670, serialized_end=2728,
) )
...@@ -1099,8 +1106,8 @@ _NAIVESGDPARAMETER = _descriptor.Descriptor( ...@@ -1099,8 +1106,8 @@ _NAIVESGDPARAMETER = _descriptor.Descriptor(
extension_ranges=[], extension_ranges=[],
oneofs=[ oneofs=[
], ],
serialized_start=2672, serialized_start=2730,
serialized_end=2738, serialized_end=2796,
) )
...@@ -1130,8 +1137,8 @@ _SUMMARYSGDPARAMETER = _descriptor.Descriptor( ...@@ -1130,8 +1137,8 @@ _SUMMARYSGDPARAMETER = _descriptor.Descriptor(
extension_ranges=[], extension_ranges=[],
oneofs=[ oneofs=[
], ],
serialized_start=2740, serialized_start=2798,
serialized_end=2799, serialized_end=2857,
) )
...@@ -1161,8 +1168,8 @@ _MOVINGAVERAGERULEPARAMETER = _descriptor.Descriptor( ...@@ -1161,8 +1168,8 @@ _MOVINGAVERAGERULEPARAMETER = _descriptor.Descriptor(
extension_ranges=[], extension_ranges=[],
oneofs=[ oneofs=[
], ],
serialized_start=2801, serialized_start=2859,
serialized_end=2847, serialized_end=2905,
) )
...@@ -1206,8 +1213,8 @@ _PSRESPONSEMESSAGE = _descriptor.Descriptor( ...@@ -1206,8 +1213,8 @@ _PSRESPONSEMESSAGE = _descriptor.Descriptor(
extension_ranges=[], extension_ranges=[],
oneofs=[ oneofs=[
], ],
serialized_start=2849, serialized_start=2907,
serialized_end=2922, serialized_end=2980,
) )
...@@ -1280,12 +1287,13 @@ _FSCLIENTPARAMETER = _descriptor.Descriptor( ...@@ -1280,12 +1287,13 @@ _FSCLIENTPARAMETER = _descriptor.Descriptor(
extension_ranges=[], extension_ranges=[],
oneofs=[ oneofs=[
], ],
serialized_start=2925, serialized_start=2983,
serialized_end=3138, serialized_end=3196,
) )
_PSPARAMETER.fields_by_name['worker_param'].message_type = _WORKERPARAMETER _PSPARAMETER.fields_by_name['worker_param'].message_type = _WORKERPARAMETER
_PSPARAMETER.fields_by_name['server_param'].message_type = _SERVERPARAMETER _PSPARAMETER.fields_by_name['server_param'].message_type = _SERVERPARAMETER
_PSPARAMETER.fields_by_name['trainer_param'].message_type = _DOWNPOURTRAINERPARAMETER
_PSPARAMETER.fields_by_name['fs_client_param'].message_type = _FSCLIENTPARAMETER _PSPARAMETER.fields_by_name['fs_client_param'].message_type = _FSCLIENTPARAMETER
_WORKERPARAMETER.fields_by_name['downpour_worker_param'].message_type = _DOWNPOURWORKERPARAMETER _WORKERPARAMETER.fields_by_name['downpour_worker_param'].message_type = _DOWNPOURWORKERPARAMETER
_SERVERPARAMETER.fields_by_name['downpour_server_param'].message_type = _DOWNPOURSERVERPARAMETER _SERVERPARAMETER.fields_by_name['downpour_server_param'].message_type = _DOWNPOURSERVERPARAMETER
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册