diff --git a/python/paddle/v2/fluid/distribute_transpiler.py b/python/paddle/v2/fluid/distribute_transpiler.py
index b064220ca29c5a7154f53a40a53028eb54940996..75e103cb80c39d0378582b7d4c0bd82cfcc6fa75 100644
--- a/python/paddle/v2/fluid/distribute_transpiler.py
+++ b/python/paddle/v2/fluid/distribute_transpiler.py
@@ -257,7 +257,45 @@ class DistributeTranspiler:
             pass
         return orig_shape
 
-    def _append_pserver_ops(self, program, opt_op, endpoint):
+    def _is_op_on_pserver(self, endpoint, all_ops, idx):
+        """
+        Recursively check if the op need to run on current server.
+        Assume that ops are in the execution order.
+        """
+        param_names = [
+            p.name for p in self.param_grad_ep_mapping[endpoint]["params"]
+        ]
+        op = all_ops[idx]
+        if op.inputs.has_key("Param"):
+            if op.inputs["Param"].name in param_names:
+                return True
+            else:
+                for n in param_names:
+                    if n.startswith(op.inputs["Param"].name + ".block") and \
+                        n != op.inputs["Param"].name:
+                        return True
+                return False
+        else:
+            j = idx - 1
+            while j >= 0:
+                prev_op = all_ops[j]
+                prev_output_names = [o.name for o in prev_op.outputs.values()]
+                prev_input_names = [o.name for o in prev_op.inputs.values()]
+                found1 = False
+                found2 = False
+                for _, v in op.inputs.iteritems():
+                    if v.name in prev_output_names:
+                        found1 = self._is_op_on_pserver(endpoint, all_ops, j)
+                # later ops may produce output for prev op's next batch use.
+                for _, v in op.outputs.iteritems():
+                    if v.name in prev_input_names:
+                        found2 = self._is_op_on_pserver(endpoint, all_ops, j)
+                if found1 or found2:
+                    return True
+                j -= 1
+            return False
+
+    def _append_pserver_ops(self, program, pserver_program, opt_op, endpoint):
         new_inputs = dict()
         # update param/grad shape first, then other inputs like
         # moment can use the updated shape
@@ -321,6 +359,14 @@ class DistributeTranspiler:
                     dtype=var.dtype,
                     shape=new_shape)
                 new_inputs[key] = tmpvar
+                # create var in pserver program global block.
+                # TODO(typhoonzero): put blocks in one program to avoid create two
+                # variables.
+                pserver_program.global_block().create_var(
+                    name=var.name,
+                    persistable=var.persistable,
+                    dtype=var.dtype,
+                    shape=new_shape)
 
         # change outputs ParamOut variable
         opt_op.outputs["ParamOut"] = new_inputs["Param"]
@@ -330,13 +376,18 @@ class DistributeTranspiler:
             outputs=opt_op.outputs,
             attrs=opt_op.attrs)
 
-    def _append_pserver_non_opt_ops(self, program, opt_op):
+    def _append_pserver_non_opt_ops(self, program, pserver_program, opt_op):
         for _, var in opt_op.inputs.iteritems():
             program.global_block().create_var(
                 name=var.name,
                 persistable=var.persistable,
                 dtype=var.dtype,
                 shape=var.shape)
+            pserver_program.global_block().create_var(
+                name=var.name,
+                persistable=var.persistable,
+                dtype=var.dtype,
+                shape=var.shape)
         program.global_block().append_op(
             type=opt_op.type,
             inputs=opt_op.inputs,
@@ -358,13 +409,18 @@ class DistributeTranspiler:
             self._clone_var(pserver_program.global_block(), v)
         # step6
         optimize_sub_program = Program()
-        for opt_op in optimize_ops:
+        for idx, opt_op in enumerate(optimize_ops):
+            is_op_on_pserver = self._is_op_on_pserver(endpoint, optimize_ops,
+                                                      idx)
+            if not is_op_on_pserver:
+                continue
             if opt_op.inputs.has_key("Grad"):
-                # append optimize_op
-                self._append_pserver_ops(optimize_sub_program, opt_op, endpoint)
+                self._append_pserver_ops(optimize_sub_program, pserver_program,
+                                         opt_op, endpoint)
             else:
-                self._append_pserver_non_opt_ops(optimize_sub_program, opt_op)
-
+                self._append_pserver_non_opt_ops(optimize_sub_program,
+                                                 pserver_program, opt_op)
+        print("****subprogram", optimize_sub_program)
         pserver_program.global_block().append_op(
             type="recv",
             inputs={"RX": self.param_grad_ep_mapping[endpoint]["grads"]
@@ -386,7 +442,7 @@ class DistributeTranspiler:
         pserver_program.sync_with_cpp()
         return pserver_program
 
-    def get_startup_program(self, endpoint):
+    def get_startup_program(self, endpoint, pserver_program):
         """
         Get startup program for current parameter server.
         Modify operator input variables if there are variables that
@@ -405,13 +461,17 @@ class DistributeTranspiler:
 
         # 1. create vars
         created_var_map = dict()
-        for var in params:
+        for _, var in pserver_program.global_block().vars.iteritems():
+            print("create var for startup", var.name, var.shape)
             tmpvar = s_prog.global_block().create_var(
                 name=var.name,
-                persistable=True,
+                persistable=var.persistable,
                 dtype=var.dtype,
                 shape=var.shape)
             created_var_map[var.name] = tmpvar
+        optimize_op_input_var_names = [
+            v.name for v in pserver_program.global_block().vars.values()
+        ]
 
         # 2. rename op outputs
         for op in orig_s_prog.global_block().ops:
@@ -423,13 +483,16 @@ class DistributeTranspiler:
                 else:
                     new_outputs[key] = var
             # do not append startup op if var is not on this pserver
-            var_on_pserver = False
-            for _, var in new_outputs.iteritems():
-                if var.name in created_var_map:
-                    var_on_pserver = True
-            if var_on_pserver:
+            op_on_pserver = False
+            for _, var in op.outputs.iteritems():
+                if var.name in optimize_op_input_var_names:
+                    op_on_pserver = True
+                    break
+
+            if op_on_pserver:
                 # gaussian_random use attr to determine tensor shape
-                op.attrs["shape"] = new_outputs["Out"].shape
+                if op.type in ["gaussian_random", "fill_constant"]:
+                    op.attrs["shape"] = new_outputs["Out"].shape
                 s_prog.global_block().append_op(
                     type=op.type,
                     inputs=op.inputs,
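Note on the new `_is_op_on_pserver` helper: it decides whether an optimize op (or an auxiliary op with no `Param` input, such as a learning-rate scale) belongs in a given pserver's program, either because that pserver owns the op's parameter (or one of its split blocks, e.g. `w.block0`), or because the op is connected through its inputs/outputs to an earlier op that is already placed there. The snippet below is a minimal, self-contained sketch of that placement rule using made-up `Op`/`Var` stand-ins rather than real fluid operators; it illustrates the idea only and is not part of this patch or the PaddlePaddle API.

```python
import collections

# Hypothetical stand-ins: an op carries dicts of named input/output variables.
Var = collections.namedtuple("Var", ["name"])
Op = collections.namedtuple("Op", ["type", "inputs", "outputs"])


def is_op_on_pserver(owned_param_names, all_ops, idx):
    """Simplified rendering of the placement rule in _is_op_on_pserver."""
    op = all_ops[idx]
    if "Param" in op.inputs:
        pname = op.inputs["Param"].name
        # Placed here if the pserver owns the parameter itself or one of its
        # split blocks, e.g. "w" matches "w.block0", "w.block1", ...
        return any(n == pname or n.startswith(pname + ".block")
                   for n in owned_param_names)
    # No Param input: walk earlier ops and inherit their placement when this
    # op consumes one of their outputs or produces one of their inputs.
    for j in range(idx - 1, -1, -1):
        prev = all_ops[j]
        prev_outs = [v.name for v in prev.outputs.values()]
        prev_ins = [v.name for v in prev.inputs.values()]
        connected = (
            any(v.name in prev_outs for v in op.inputs.values()) or
            any(v.name in prev_ins for v in op.outputs.values()))
        if connected and is_op_on_pserver(owned_param_names, all_ops, j):
            return True
    return False


if __name__ == "__main__":
    ops = [
        Op("sgd",
           {"Param": Var("w.block0"), "Grad": Var("w_grad.block0"),
            "LearningRate": Var("lr")},
           {"ParamOut": Var("w.block0")}),
        Op("scale", {"X": Var("w.block0")}, {"Out": Var("w_scaled")}),
    ]
    print([is_op_on_pserver(["w.block0"], ops, i) for i in range(2)])  # [True, True]
    print([is_op_on_pserver(["v.block0"], ops, i) for i in range(2)])  # [False, False]
```

In the first case the sgd op is placed directly through parameter ownership and the scale op inherits placement via the shared `w.block0` variable; in the second case neither check succeeds, which is what lets `get_pserver_program` skip ops that do not belong on the current endpoint.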