From 7a6000a0b879719ea25e4c882ae6be79845ee57f Mon Sep 17 00:00:00 2001
From: typhoonzero
Date: Thu, 8 Feb 2018 12:08:13 +0800
Subject: [PATCH] follow comments

---
 python/paddle/v2/fluid/distribute_transpiler.py | 10 ++++------
 1 file changed, 4 insertions(+), 6 deletions(-)

diff --git a/python/paddle/v2/fluid/distribute_transpiler.py b/python/paddle/v2/fluid/distribute_transpiler.py
index 4eb103cc6b..c5f1d51bd7 100644
--- a/python/paddle/v2/fluid/distribute_transpiler.py
+++ b/python/paddle/v2/fluid/distribute_transpiler.py
@@ -385,7 +385,7 @@ class DistributeTranspiler:
                 # param is already created on global program
                 param_block = None
                 for p in self.param_grad_ep_mapping[endpoint]["params"]:
-                    if same_or_split_var(p.name, opt_op.input(key)):
+                    if same_or_split_var(p.name, opt_op.input(key)[0]):
                         param_block = p
                         break
                 if not param_block:
@@ -403,7 +403,7 @@ class DistributeTranspiler:
                 continue
             # update accumulator variable shape
             param_shape = new_inputs["Param"].shape
-            var = program.global_block().vars[opt_op.input(key)]
+            var = program.global_block().vars[opt_op.input(key)[0]]
             new_shape = self._get_optimizer_input_shape(opt_op.type, key,
                                                         var.shape, param_shape)
             tmpvar = program.global_block().create_var(
@@ -440,20 +440,18 @@ class DistributeTranspiler:
             else:
                 varlist = [var]
             for var in varlist:
+                # TODO(typhoonzero): will remove below line later.
                 program.global_block().create_var(
                     name=var.name,
                     persistable=var.persistable,
                     dtype=var.dtype,
                     shape=var.shape)
-                try:
+                if not pserver_program.global_block().vars.has_key(var.name):
                     pserver_program.global_block().create_var(
                         name=var.name,
                         persistable=var.persistable,
                         dtype=var.dtype,
                         shape=var.shape)
-                except ValueError:
-                    # create var if not created yet.
-                    pass

         outputs = self._get_output_map_from_op(self.program.global_block().vars,
                                                opt_op)
--
GitLab
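
Reviewer note (not part of the patch above): `Operator.input(key)` in Fluid returns the list of input variable names bound to that input slot, which is why the first two hunks index out the single name with `[0]` before comparing it or using it as a key into `global_block().vars`. The last hunk replaces the try/except ValueError around `create_var` with an explicit existence check, so creating a variable on the pserver program becomes a no-op when the variable is already there. Below is a minimal sketch of that guard pattern, with a plain dict standing in for a block's `vars` map; the helper name is illustrative, not a Paddle API.

# Illustrative sketch only: `block_vars` stands in for Block.vars and
# `create_var_once` is a hypothetical helper, not part of Paddle.
def create_var_once(block_vars, name, **attrs):
    # Same intent as the has_key() guard in the patch: only create the
    # variable if this block has not seen it yet, so repeated calls
    # (one per op that references the variable) are harmless no-ops.
    if name not in block_vars:
        block_vars[name] = dict(name=name, **attrs)
    return block_vars[name]

pserver_vars = {}
create_var_once(pserver_vars, "fc_0.w_0", persistable=True, shape=[784, 10])
create_var_once(pserver_vars, "fc_0.w_0", persistable=True, shape=[784, 10])  # no-op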