diff --git a/python/paddle/v2/fluid/distribute_transpiler.py b/python/paddle/v2/fluid/distribute_transpiler.py
index ac13a7cb60a306a4c542b192fe97d7009c2192ef..76e8734f1371d0874d8b10f597a9bb6989de9297 100644
--- a/python/paddle/v2/fluid/distribute_transpiler.py
+++ b/python/paddle/v2/fluid/distribute_transpiler.py
@@ -114,12 +114,15 @@ class DistributeTranspiler:
         # step3
         send_inputs = []
         send_outputs = []
-        for _, splited in grad_var_mapping.iteritems():
-            send_inputs.extend(splited)
+        for b in grad_blocks:  # append by order
+            varname, block_id, _ = b.split(":")
+            send_inputs.append(grad_var_mapping[varname][int(block_id)])
+
         param_var_mapping = self._create_vars_from_blocklist(program,
                                                              param_blocks)
-        for _, splited in param_var_mapping.iteritems():
-            send_outputs.extend(splited)
+        for b in param_blocks:
+            varname, block_id, _ = b.split(":")
+            send_outputs.append(param_var_mapping[varname][int(block_id)])
         # let send_op know which endpoint to send which var, eplist is of the same
         # order of send_inputs.
         eplist = split_method(send_inputs, pserver_endpoints)
@@ -243,8 +246,37 @@ class DistributeTranspiler:
             var_list.append(var_each)
         return var_list
 
-    def _append_pserver_ops(self, opt_op, endpoint):
+    def _get_optimizer_input_shape(self, op_type, varkey, orig_shape,
+                                   param_shape):
+        """
+        Returns the shape for optimizer inputs that need to be reshaped when
+        Param and Grad is splited to multiple servers.
+        """
+        # HACK(typhoonzero): Should use functions of corresponding optimizer in
+        # optimizer.py to get the shape, do not bind this in the transpiler.
+        if op_type == "adam":
+            if varkey in ["Moment1", "Moment2"]:
+                return param_shape
+        elif op_type == "adagrad":
+            if varkey == "Moment":
+                return param_shape
+        elif op_type == "adamax":
+            if varkey in ["Moment", "InfNorm"]:
+                return param_shape
+        elif op_type == "momentum":
+            if varkey == "Velocity":
+                return param_shape
+        elif op_type == "":
+            if varkey == "Moment":
+                return param_shape
+        elif op_type == "sgd":
+            pass
+        return orig_shape
+
+    def _append_pserver_ops(self, program, opt_op, endpoint):
         new_inputs = dict()
+        # update param/grad shape first, then other inputs like
+        # moment can use the updated shape
         for key, var in opt_op.inputs.iteritems():
             if key == "Grad":
                 grad_block = None
@@ -256,7 +288,7 @@ class DistributeTranspiler:
                     # do not append this op if current endpoint
                     # is not dealing with this grad block
                     return
-                merged_var = optimize_sub_program.global_block().create_var(
+                merged_var = program.global_block().create_var(
                     name=grad_block.name,
                     persistable=grad_block.persistable,
                     dtype=grad_block.dtype,
@@ -264,13 +296,12 @@ class DistributeTranspiler:
                 # append merging ops if trainers > 1
                 if self.trainers > 1:
                     vars2merge = self._create_var_for_trainers(
-                        optimize_sub_program.global_block(), grad_block,
-                        self.trainers)
-                    optimize_sub_program.global_block().append_op(
+                        program.global_block(), grad_block, self.trainers)
+                    program.global_block().append_op(
                         type="sum",
                         inputs={"X": vars2merge},
                         outputs={"Out": merged_var})
-                    optimize_sub_program.global_block().append_op(
+                    program.global_block().append_op(
                         type="scale",
                         inputs={"X": merged_var},
                         outputs={"Out": merged_var},
@@ -285,37 +316,45 @@ class DistributeTranspiler:
                         break
                 if not param_block:
                     return
-                tmpvar = optimize_sub_program.global_block().create_var(
+                tmpvar = program.global_block().create_var(
                     name=param_block.name,
                     persistable=param_block.persistable,
                     dtype=param_block.dtype,
                     shape=param_block.shape)
                 new_inputs[key] = tmpvar
-            else:
-                tmpvar = optimize_sub_program.global_block().create_var(
-                    name=var.name,
-                    persistable=var.persistable,
-                    dtype=var.dtype,
-                    shape=var.shape)
-                new_inputs[key] = tmpvar
+
+        for key, var in opt_op.inputs.iteritems():
+            if key in ["Param", "Grad"]:
+                continue
+            # update accumulator variable shape
+            param_shape = new_inputs["Param"].shape
+            new_shape = self._get_optimizer_input_shape(opt_op.type, key,
+                                                        var.shape, param_shape)
+            print("var, new shape", key, var.name, new_shape)
+            tmpvar = program.global_block().create_var(
+                name=var.name,
+                persistable=var.persistable,
+                dtype=var.dtype,
+                shape=new_shape)
+            new_inputs[key] = tmpvar
 
         # FIXME: change outputs ParamOut
-        optimize_sub_program.global_block().append_op(
+        program.global_block().append_op(
             type=opt_op.type,
             inputs=new_inputs,
             outputs=opt_op.outputs,
             attrs=opt_op.attrs)
 
-    def _append_pserver_non_opt_ops(self, opt_op):
+    def _append_pserver_non_opt_ops(self, program, opt_op):
         for _, var in opt_op.inputs.iteritems():
-            optimize_sub_program.global_block().create_var(
+            program.global_block().create_var(
                 name=var.name,
                 persistable=var.persistable,
                 dtype=var.dtype,
                 shape=var.shape)
-        optimize_sub_program.global_block().append_op(
+        program.global_block().append_op(
             type=opt_op.type,
-            inputs=new_inputs,
+            inputs=opt_op.inputs,
             outputs=opt_op.outputs,
             attrs=opt_op.attrs)
 
@@ -331,15 +370,15 @@ class DistributeTranspiler:
         # step5
         pserver_program = Program()
         for v in self.param_grad_ep_mapping[endpoint]["params"]:
-            self._clone_param(pserver_program.global_block(), v)
+            self._clone_var(pserver_program.global_block(), v)
         # step6
         optimize_sub_program = Program()
        for opt_op in optimize_ops:
-            if opt_ops.inputs.has_key("Grad"):
+            if opt_op.inputs.has_key("Grad"):
                 # append optimize_op
-                self._append_pserver_ops(opt_op, endpoint)
+                self._append_pserver_ops(optimize_sub_program, opt_op, endpoint)
             else:
-                self._append_pserver_non_opt_ops(opt_op)
+                self._append_pserver_non_opt_ops(optimize_sub_program, opt_op)
         pserver_program.global_block().append_op(
             type="recv",
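
Note on the new _get_optimizer_input_shape helper: when a parameter and its gradient
are split into blocks across parameter servers, per-parameter optimizer accumulators
(Adam's Moment1/Moment2, Adagrad's Moment, momentum's Velocity, and so on) must be
created on each pserver with the split block's shape rather than the original full
shape, while other optimizer inputs keep their original shape. The standalone sketch
below is not part of the patch; it restates that rule as a lookup table with
illustrative names and shapes to show the intent.

# Minimal sketch mirroring the shape rule this patch adds; the function name,
# example shapes, and example variable keys below are illustrative assumptions,
# not Paddle API calls.
def accumulator_shape(op_type, varkey, orig_shape, param_shape):
    # Accumulator inputs that must follow the (split) parameter shape, per optimizer.
    per_param_keys = {
        "adam": ["Moment1", "Moment2"],
        "adagrad": ["Moment"],
        "adamax": ["Moment", "InfNorm"],
        "momentum": ["Velocity"],
    }
    if varkey in per_param_keys.get(op_type, []):
        return param_shape
    # Everything else (learning rate, beta powers, ...) keeps its original shape.
    return orig_shape

if __name__ == "__main__":
    # A 1000x10 parameter split across 4 pservers yields 250x10 blocks, so
    # Adam's moments must also be created as 250x10 on each pserver.
    print(accumulator_shape("adam", "Moment1", (1000, 10), (250, 10)))  # (250, 10)
    print(accumulator_shape("adam", "Beta1Pow", (1, ), (250, 10)))  # (1,)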