diff --git a/python/paddle/fluid/distribute_transpiler.py b/python/paddle/fluid/distribute_transpiler.py index 7a2a81be9f269f262160cd082ec3a1d8e8e46811..3c6be913200716ae4f70e2b48ee8faf8078223d2 100644 --- a/python/paddle/fluid/distribute_transpiler.py +++ b/python/paddle/fluid/distribute_transpiler.py @@ -102,6 +102,8 @@ def split_dense_variable(var_list, the parameter server side can gain better performance. By default minimum block size is 1024. The max block size is used to prevent very large blocks that may cause send error. + :return: A list of VarBlocks. Each VarBlock specifies a shard of + the var. """ blocks = [] for var in var_list: @@ -192,22 +194,24 @@ class DistributeTranspiler: self.trainer_id = trainer_id pserver_endpoints = pservers.split(",") - # step1 + # step1: For large parameters and gradients, split them into smaller + # blocks. param_list = [pg[0] for pg in params_grads] grad_list = [pg[1] for pg in params_grads] grad_blocks = split_dense_variable(grad_list, len(pserver_endpoints)) param_blocks = split_dense_variable(param_list, len(pserver_endpoints)) - # step2 + # step2: Create new vars for the parameters and gradients blocks and + # add ops to do the split. grad_var_mapping = self._append_split_op(program, grad_blocks) - # step3 + param_var_mapping = self._create_vars_from_blocklist(program, + param_blocks) + # step3: Add gradients as send op inputs and parameters as send + # op outputs. send_inputs = [] send_outputs = [] for b in grad_blocks: # append by order varname, block_id, _ = b.split(":") send_inputs.append(grad_var_mapping[varname][int(block_id)]) - - param_var_mapping = self._create_vars_from_blocklist(program, - param_blocks) for b in param_blocks: varname, block_id, _ = b.split(":") send_outputs.append(param_var_mapping[varname][int(block_id)]) @@ -237,7 +241,7 @@ class DistributeTranspiler: "RPCClient": rpc_client_var}, attrs={"endpoints": pserver_endpoints, "epmap": eplist}) - # step4 + # step4: Concat the parameters splits together after recv. for varname, splited_var in param_var_mapping.iteritems(): if len(splited_var) <= 1: continue @@ -258,13 +262,14 @@ class DistributeTranspiler: def get_pserver_program(self, endpoint): """ Get pserver side program using the endpoint. + TODO(panyx0718): Revisit this assumption. what if #blocks > #pservers. NOTE: assume blocks of the same variable is not distributed on the same pserver, only change param/grad varnames for trainers to fetch. """ # step1 pserver_program = Program() - # step2 + # step2: Create vars to receive vars at parameter servers. recv_inputs = [] for v in self.param_grad_ep_mapping[endpoint]["params"]: self._clone_var(pserver_program.global_block(), v) @@ -278,12 +283,6 @@ class DistributeTranspiler: orig_var_name = v.name[:suff_idx] else: orig_var_name = v.name - single_trainer_var = pserver_program.global_block().create_var( - name=orig_var_name, - persistable=True, - type=v.type, - dtype=v.dtype, - shape=v.shape) if self.trainers > 1: for trainer_id in xrange(self.trainers): var = pserver_program.global_block().create_var( @@ -294,6 +293,12 @@ class DistributeTranspiler: shape=v.shape) recv_inputs.append(var) else: + single_trainer_var = pserver_program.global_block().create_var( + name=orig_var_name, + persistable=True, + type=v.type, + dtype=v.dtype, + shape=v.shape) recv_inputs.append(single_trainer_var) # step3 @@ -344,7 +349,7 @@ class DistributeTranspiler: self._append_pserver_non_opt_ops(block, op) append_block = optimize_block - # append lr decay ops to the child block if exits + # append lr decay ops to the child block if exists lr_ops = self._get_lr_ops() if len(lr_ops) > 0: for _, op in enumerate(lr_ops): @@ -447,8 +452,10 @@ class DistributeTranspiler: block_list, add_trainer_suffix=False): """ + Create vars for each split. NOTE: only grads need to be named for different trainers, use add_trainer_suffix to rename the grad vars. + :return: A dict mapping from original var name to each var split. """ block_map = dict() var_mapping = dict() @@ -615,6 +622,7 @@ class DistributeTranspiler: type="sum", inputs={"X": vars2merge}, outputs={"Out": merged_var}) + # TODO(panyx0718): What if it's SELECTED_ROWS. if not merged_var.type == core.VarDesc.VarType.SELECTED_ROWS: optimize_block.append_op( type="scale", @@ -638,7 +646,7 @@ class DistributeTranspiler: shape=param_block.shape) new_inputs[key] = tmpvar elif key == "LearningRate": - # leraning rate variable has already be created by non-optimize op, + # learning rate variable has already be created by non-optimize op, # don't create it once again. lr_varname = opt_op.input(key)[0] if pserver_block.vars.has_key(lr_varname): @@ -773,6 +781,7 @@ class DistributeTranspiler: return False def _get_input_map_from_op(self, varmap, op): + """Returns a dict from op input name to the vars in varmap.""" iomap = dict() for key in op.input_names: vars = [] @@ -785,6 +794,7 @@ class DistributeTranspiler: return iomap def _get_output_map_from_op(self, varmap, op): + """Returns a dict from op output name to the vars in varmap.""" iomap = dict() for key in op.output_names: vars = [] @@ -812,6 +822,7 @@ class DistributeTranspiler: find_ops.append(op) # make a union find struct by the ops in default_main_program ufind = UnionFind(block.ops) + for op1 in block.ops: for op2 in block.ops: # NOTE: we need to skip all optimize ops, since it is connected