From 5135f05cf19c30af52a92637d2270938a0e01585 Mon Sep 17 00:00:00 2001 From: Yancey1989 Date: Fri, 9 Feb 2018 16:05:31 +0800 Subject: [PATCH] create optimize block in pserver program --- .../paddle/v2/fluid/distribute_transpiler.py | 43 +++++++------------ 1 file changed, 15 insertions(+), 28 deletions(-) diff --git a/python/paddle/v2/fluid/distribute_transpiler.py b/python/paddle/v2/fluid/distribute_transpiler.py index c5f1d51bd71..cd89dba72db 100644 --- a/python/paddle/v2/fluid/distribute_transpiler.py +++ b/python/paddle/v2/fluid/distribute_transpiler.py @@ -347,7 +347,8 @@ class DistributeTranspiler: j -= 1 return False - def _append_pserver_ops(self, program, pserver_program, opt_op, endpoint): + def _append_pserver_ops(self, optimize_block, opt_op, endpoint): + program = optimize_block.program new_inputs = dict() # update param/grad shape first, then other inputs like # moment can use the updated shape @@ -371,11 +372,11 @@ class DistributeTranspiler: if self.trainers > 1: vars2merge = self._create_var_for_trainers( program.global_block(), grad_block, self.trainers) - program.global_block().append_op( + optimize_block.append_op( type="sum", inputs={"X": vars2merge}, outputs={"Out": merged_var}) - program.global_block().append_op( + optimize_block.append_op( type="scale", inputs={"X": merged_var}, outputs={"Out": merged_var}, @@ -412,25 +413,18 @@ class DistributeTranspiler: dtype=var.dtype, shape=new_shape) new_inputs[key] = tmpvar - # create var in pserver program global block. - # TODO(typhoonzero): put blocks in one program to avoid create two - # variables. - pserver_program.global_block().create_var( - name=var.name, - persistable=var.persistable, - dtype=var.dtype, - shape=new_shape) # change output's ParamOut variable outputs = self._get_output_map_from_op(program.global_block(), opt_op) outputs["ParamOut"] = new_inputs["Param"] - program.global_block().append_op( + optimize_block.append_op( type=opt_op.type, inputs=new_inputs, outputs=outputs, attrs=opt_op.attrs) - def _append_pserver_non_opt_ops(self, program, pserver_program, opt_op): + def _append_pserver_non_opt_ops(self, optimize_block, opt_op): + program = optimize_block.program # Append the ops for parameters that do not need to be optimized/updated inputs = self._get_input_map_from_op(self.program.global_block().vars, opt_op) @@ -440,14 +434,8 @@ class DistributeTranspiler: else: varlist = [var] for var in varlist: - # TODO(typhoonzero): will remove below line later. - program.global_block().create_var( - name=var.name, - persistable=var.persistable, - dtype=var.dtype, - shape=var.shape) - if not pserver_program.global_block().vars.has_key(var.name): - pserver_program.global_block().create_var( + if not program.global_block().vars.has_key(var.name): + program.global_block().create_var( name=var.name, persistable=var.persistable, dtype=var.dtype, @@ -456,7 +444,7 @@ class DistributeTranspiler: outputs = self._get_output_map_from_op(self.program.global_block().vars, opt_op) - program.global_block().append_op( + optimize_block.append_op( type=opt_op.type, inputs=inputs, outputs=outputs, @@ -489,7 +477,7 @@ class DistributeTranspiler: dtype=v.dtype, shape=v.shape) # step6 - optimize_sub_program = Program() + optimize_block = pserver_program.create_block(0) # Iterate through the ops and append ops as needed for idx, opt_op in enumerate(self.optimize_ops): is_op_on_pserver = self._is_op_on_pserver(endpoint, @@ -497,18 +485,17 @@ class DistributeTranspiler: if not is_op_on_pserver: continue if "Grad" in opt_op.desc.input_arg_names(): - self._append_pserver_ops(optimize_sub_program, pserver_program, - opt_op, endpoint) + self._append_pserver_ops(optimize_block, opt_op, endpoint) else: - self._append_pserver_non_opt_ops(optimize_sub_program, - pserver_program, opt_op) + self._append_pserver_non_opt_ops(optimize_block, opt_op) + # Append the listen_and_serv op pserver_program.global_block().append_op( type="listen_and_serv", inputs={}, outputs={}, attrs={ - "OptimizeBlock": optimize_sub_program.global_block(), + "OptimizeBlock": optimize_block, "endpoint": endpoint, "ParamList": [ p.name -- GitLab