From c74445017d35ad344c5fc2a19a35c47a72358f3c Mon Sep 17 00:00:00 2001 From: typhoonzero Date: Wed, 7 Feb 2018 19:18:48 +0800 Subject: [PATCH] refine distribute transpiler --- .../paddle/v2/fluid/distribute_transpiler.py | 124 +++++++++++++----- 1 file changed, 89 insertions(+), 35 deletions(-) diff --git a/python/paddle/v2/fluid/distribute_transpiler.py b/python/paddle/v2/fluid/distribute_transpiler.py index 121b407cae4..4eb103cc6ba 100644 --- a/python/paddle/v2/fluid/distribute_transpiler.py +++ b/python/paddle/v2/fluid/distribute_transpiler.py @@ -300,6 +300,9 @@ class DistributeTranspiler: pass return orig_shape + def _op_input_var(self, op, varname): + pass + def _is_op_on_pserver(self, endpoint, all_ops, idx): """ Recursively check if the op need to run on current server. @@ -309,29 +312,35 @@ class DistributeTranspiler: p.name for p in self.param_grad_ep_mapping[endpoint]["params"] ] op = all_ops[idx] - if op.inputs.has_key("Param"): - if op.inputs["Param"].name in param_names: + input_names = set(op.input_names) + # TODO(typhoonzero): using Param and Grad input name to identify + # that the operator is an optimization operator, need a better way. + if "Param" in input_names: + if op.input("Param")[0] in param_names: return True else: for n in param_names: - if same_or_split_var(n, op.inputs[ - "Param"].name) and n != op.inputs["Param"].name: + if same_or_split_var(n, op.input("Param")[0]) \ + and n != op.input("Param")[0]: return True return False else: j = idx - 1 while j >= 0: prev_op = all_ops[j] - prev_output_names = [o.name for o in prev_op.outputs.values()] - prev_input_names = [o.name for o in prev_op.inputs.values()] + # prev_output_names = [o.name for o in prev_op.outputs.values()] + # prev_input_names = [o.name for o in prev_op.inputs.values()] + # NOTE(typhoonzero): consider list input/output + prev_output_names = prev_op.desc.output_arg_names() + prev_input_names = prev_op.desc.input_arg_names() found1 = False found2 = False - for _, v in op.inputs.iteritems(): - if v.name in prev_output_names: + for varname in op.desc.input_arg_names(): + if varname in prev_output_names: found1 = self._is_op_on_pserver(endpoint, all_ops, j) # later ops may produce output for prev op's next batch use. - for _, v in op.outputs.iteritems(): - if v.name in prev_input_names: + for varname in op.desc.output_arg_names(): + if varname in prev_input_names: found2 = self._is_op_on_pserver(endpoint, all_ops, j) if found1 or found2: return True @@ -342,11 +351,11 @@ class DistributeTranspiler: new_inputs = dict() # update param/grad shape first, then other inputs like # moment can use the updated shape - for key, var in opt_op.inputs.iteritems(): + for key in opt_op.input_names: if key == "Grad": grad_block = None for g in self.param_grad_ep_mapping[endpoint]["grads"]: - if same_or_split_var(g.name, var.name): + if same_or_split_var(g.name, opt_op.input(key)[0]): grad_block = g break if not grad_block: @@ -376,7 +385,7 @@ class DistributeTranspiler: # param is already created on global program param_block = None for p in self.param_grad_ep_mapping[endpoint]["params"]: - if same_or_split_var(p.name, var.name): + if same_or_split_var(p.name, opt_op.input(key)): param_block = p break if not param_block: @@ -389,11 +398,12 @@ class DistributeTranspiler: new_inputs[key] = tmpvar - for key, var in opt_op.inputs.iteritems(): + for key in opt_op.input_names: if key in ["Param", "Grad"]: continue # update accumulator variable shape param_shape = new_inputs["Param"].shape + var = program.global_block().vars[opt_op.input(key)] new_shape = self._get_optimizer_input_shape(opt_op.type, key, var.shape, param_shape) tmpvar = program.global_block().create_var( @@ -412,30 +422,46 @@ class DistributeTranspiler: shape=new_shape) # change output's ParamOut variable - opt_op.outputs["ParamOut"] = new_inputs["Param"] + outputs = self._get_output_map_from_op(program.global_block(), opt_op) + outputs["ParamOut"] = new_inputs["Param"] program.global_block().append_op( type=opt_op.type, inputs=new_inputs, - outputs=opt_op.outputs, + outputs=outputs, attrs=opt_op.attrs) def _append_pserver_non_opt_ops(self, program, pserver_program, opt_op): # Append the ops for parameters that do not need to be optimized/updated - for _, var in opt_op.inputs.iteritems(): - program.global_block().create_var( - name=var.name, - persistable=var.persistable, - dtype=var.dtype, - shape=var.shape) - pserver_program.global_block().create_var( - name=var.name, - persistable=var.persistable, - dtype=var.dtype, - shape=var.shape) + inputs = self._get_input_map_from_op(self.program.global_block().vars, + opt_op) + for var in inputs.itervalues(): + if type(var) == list: + varlist = var + else: + varlist = [var] + for var in varlist: + program.global_block().create_var( + name=var.name, + persistable=var.persistable, + dtype=var.dtype, + shape=var.shape) + try: + pserver_program.global_block().create_var( + name=var.name, + persistable=var.persistable, + dtype=var.dtype, + shape=var.shape) + except ValueError: + # create var if not created yet. + pass + + outputs = self._get_output_map_from_op(self.program.global_block().vars, + opt_op) + program.global_block().append_op( type=opt_op.type, - inputs=opt_op.inputs, - outputs=opt_op.outputs, + inputs=inputs, + outputs=outputs, attrs=opt_op.attrs) def get_pserver_program(self, endpoint): @@ -472,7 +498,7 @@ class DistributeTranspiler: self.optimize_ops, idx) if not is_op_on_pserver: continue - if opt_op.inputs.has_key("Grad"): + if "Grad" in opt_op.desc.input_arg_names(): self._append_pserver_ops(optimize_sub_program, pserver_program, opt_op, endpoint) else: @@ -499,6 +525,30 @@ class DistributeTranspiler: pserver_program.sync_with_cpp() return pserver_program + def _get_input_map_from_op(self, varmap, op): + iomap = dict() + for key in op.input_names: + vars = [] + for varname in op.input(key): + vars.append(varmap[varname]) + if len(vars) == 1: + iomap[key] = vars[0] + else: + iomap[key] = vars + return iomap + + def _get_output_map_from_op(self, varmap, op): + iomap = dict() + for key in op.output_names: + vars = [] + for varname in op.output(key): + vars.append(varmap[varname]) + if len(vars) == 1: + iomap[key] = vars[0] + else: + iomap[key] = vars + return iomap + def get_startup_program(self, endpoint, pserver_program): """ Get startup program for current parameter server. @@ -529,17 +579,21 @@ class DistributeTranspiler: # 2. rename op outputs for op in orig_s_prog.global_block().ops: + new_inputs = dict() new_outputs = dict() # do not append startup op if var is not on this pserver op_on_pserver = False - for key, var in op.outputs.iteritems(): - newname, _ = _get_splited_name_and_shape(var.name) + for key in op.output_names: + newname, _ = _get_splited_name_and_shape(op.output(key)[0]) if newname: op_on_pserver = True new_outputs[key] = created_var_map[newname] - elif var.name in pserver_vars: + elif op.output(key)[0] in pserver_vars: op_on_pserver = True - new_outputs[key] = pserver_vars[var.name] + new_outputs[key] = pserver_vars[op.output(key)[0]] + + # most startup program ops have no inputs + new_inputs = self._get_input_map_from_op(pserver_vars, op) if op_on_pserver: if op.type in [ @@ -548,7 +602,7 @@ class DistributeTranspiler: op.attrs["shape"] = new_outputs["Out"].shape s_prog.global_block().append_op( type=op.type, - inputs=op.inputs, + inputs=new_inputs, outputs=new_outputs, attrs=op.attrs) return s_prog -- GitLab