From 3759600019f206794d5852bbbc74fd959337cf3d Mon Sep 17 00:00:00 2001 From: dongdaxiang Date: Thu, 13 Dec 2018 23:01:53 +0800 Subject: [PATCH] add doc string for downpour.py and distribute_lookup_table.py --- .../paddle/fluid/distribute_lookup_table.py | 32 ++++++++++++--- python/paddle/fluid/distributed/downpour.py | 41 ++++++++++++++----- 2 files changed, 57 insertions(+), 16 deletions(-) diff --git a/python/paddle/fluid/distribute_lookup_table.py b/python/paddle/fluid/distribute_lookup_table.py index 243d806c41..74824f6832 100644 --- a/python/paddle/fluid/distribute_lookup_table.py +++ b/python/paddle/fluid/distribute_lookup_table.py @@ -16,31 +16,51 @@ LOOKUP_TABLE_TYPE = "lookup_table" def find_distributed_lookup_table_inputs(program, table_name): + """ + Find input variable of distribute lookup table in program. + We only support one distribute table now. + Args: + program(Program): given program, locate distributed lookup table + table_name(str): given table name that is found beforehand + Returns: + inputs + """ local_vars = program.current_block().vars inputs = [] for op in program.global_block().ops: if op.type == LOOKUP_TABLE_TYPE: if table_name == op.input("W")[0]: - inputs.extend( - [local_vars[name] for name in op.input("Ids")]) + inputs.extend([local_vars[name] for name in op.input("Ids")]) return inputs + def find_distributed_lookup_table_outputs(program, table_name): + """ + Find output variable of distribute lookup table in program. + We only support one distribute table now. + Args: + program(Program): given program, locate distributed lookup table + table_name(str): given table name that is found beforehand + Returns: + outputs + """ local_vars = program.current_block().vars outputs = [] for op in program.global_block().ops: if op.type == LOOKUP_TABLE_TYPE: if table_name == op.input("W")[0]: - outputs.extend( - [local_vars[name] for name in op.output("Out")]) + outputs.extend([local_vars[name] for name in op.output("Out")]) return outputs + def find_distributed_lookup_table(program): """ Find distribute lookup table in program. We only support one distribute table now. - :param program: - :return: table_name or None + Args: + program(Program): given program, locate distributed lookup table + Returns: + table_name or None """ table_name = None diff --git a/python/paddle/fluid/distributed/downpour.py b/python/paddle/fluid/distributed/downpour.py index 9ef9e14ccc..87dfab92c5 100644 --- a/python/paddle/fluid/distributed/downpour.py +++ b/python/paddle/fluid/distributed/downpour.py @@ -20,6 +20,7 @@ from paddle.fluid.distribute_lookup_table import find_distributed_lookup_table_i from paddle.fluid.distribute_lookup_table import find_distributed_lookup_table_outputs from google.protobuf import text_format + class DownpourSGD(object): """ Distributed optimizer of downpour stochastic gradient descent @@ -35,17 +36,38 @@ class DownpourSGD(object): downpour_sgd = fluid.distributed.DownpourSGD(learning_rate=0.2) downpour_sgd.minimize(cost) """ + def __init__(self, learning_rate=0.001, window=1): # todo(guru4elephant): add more optimizers here as argument # todo(guru4elephant): make learning_rate as a variable self.learning_rate_ = learning_rate self.window_ = window self.type = "downpour" - - def minimize(self, loss, startup_program=None, - parameter_list=None, no_grad_set=None): - params_grads = sorted(append_backward( - loss, parameter_list, no_grad_set), key=lambda x:x[0].name) + + def minimize(self, + loss, + startup_program=None, + parameter_list=None, + no_grad_set=None): + """ + DownpounSGD is a distributed optimizer so + that user can call minimize to generate backward + operators and optimization operators within minmize function + Args: + loss(Variable): loss variable defined by user + startup_program(Program): startup program that defined by user + parameter_list(str list): parameter names defined by users + no_grad_set(set): a set of variables that is defined by users + so that these variables do not need gradient computation + Returns: + [ps_param, worker_skipped_ops] + ps_param: parameter server protobuf desc + worker_skipped_ops: operator names that need + to be skipped during execution + """ + params_grads = sorted( + append_backward(loss, parameter_list, no_grad_set), + key=lambda x: x[0].name) table_name = find_distributed_lookup_table(loss.block.program) prefetch_slots = find_distributed_lookup_table_inputs( loss.block.program, table_name) @@ -67,12 +89,12 @@ class DownpourSGD(object): grads.append(i[1]) server.add_sparse_table(sparse_table_index, self.learning_rate_, prefetch_slots, prefetch_slots_emb) - server.add_dense_table(dense_table_index, self.learning_rate_, - params, grads) + server.add_dense_table(dense_table_index, self.learning_rate_, params, + grads) worker.add_sparse_table(sparse_table_index, self.learning_rate_, prefetch_slots, prefetch_slots_emb) - worker.add_dense_table(dense_table_index, self.learning_rate_, - params, grads) + worker.add_dense_table(dense_table_index, self.learning_rate_, params, + grads) ps_param = pslib.PSParameter() ps_param.server_param.CopyFrom(server.get_desc()) ps_param.trainer_param.CopyFrom(worker.get_desc()) @@ -80,5 +102,4 @@ class DownpourSGD(object): # currently only support lookup_table worker_skipped_ops = ["lookup_table", "lookup_table_grad"] ps_param.trainer_param.skip_op.extend(worker_skipped_ops) - ps_param_str = text_format.MessageToString(ps_param) return [ps_param, worker_skipped_ops] -- GitLab