diff --git a/python/paddle/v2/fluid/distribute_transpiler.py b/python/paddle/v2/fluid/distribute_transpiler.py
index 4eb103cc6ba2c76949cb71f862a4ea26d4789a42..c5f1d51bd718acf32d173b97ee7bb7cdeb443c63 100644
--- a/python/paddle/v2/fluid/distribute_transpiler.py
+++ b/python/paddle/v2/fluid/distribute_transpiler.py
@@ -385,7 +385,7 @@ class DistributeTranspiler:
                 # param is already created on global program
                 param_block = None
                 for p in self.param_grad_ep_mapping[endpoint]["params"]:
-                    if same_or_split_var(p.name, opt_op.input(key)):
+                    if same_or_split_var(p.name, opt_op.input(key)[0]):
                         param_block = p
                         break
                 if not param_block:
@@ -403,7 +403,7 @@ class DistributeTranspiler:
                     continue
                 # update accumulator variable shape
                 param_shape = new_inputs["Param"].shape
-                var = program.global_block().vars[opt_op.input(key)]
+                var = program.global_block().vars[opt_op.input(key)[0]]
                 new_shape = self._get_optimizer_input_shape(opt_op.type, key,
                                                             var.shape, param_shape)
                 tmpvar = program.global_block().create_var(
@@ -440,20 +440,18 @@ class DistributeTranspiler:
             else:
                 varlist = [var]
             for var in varlist:
+                # TODO(typhoonzero): will remove below line later.
                 program.global_block().create_var(
                     name=var.name,
                     persistable=var.persistable,
                     dtype=var.dtype,
                     shape=var.shape)
-                try:
+                if not pserver_program.global_block().vars.has_key(var.name):
                     pserver_program.global_block().create_var(
                         name=var.name,
                         persistable=var.persistable,
                         dtype=var.dtype,
                         shape=var.shape)
-                except ValueError:
-                    # create var if not created yet.
-                    pass

         outputs = self._get_output_map_from_op(self.program.global_block().vars,
                                                opt_op)
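
The first two hunks index `opt_op.input(key)` with `[0]`: in fluid, an operator's `input(name)` returns the list of variable names bound to that input slot, even when the slot holds a single variable, so the unindexed value is a list where a name string is expected. Below is a minimal sketch of the failure mode and the fix; `ToyOp` is a hypothetical stand-in, not Paddle's actual `Operator` class.

```python
# Minimal sketch, assuming a toy operator whose input(key) returns a
# list of variable names, mirroring fluid's Operator.input() behavior.
class ToyOp(object):
    def __init__(self, inputs):
        self._inputs = inputs  # e.g. {"Param": ["fc_0.w_0"]}

    def input(self, key):
        # Always returns a list, even for a single-variable slot.
        return self._inputs[key]

opt_op = ToyOp({"Param": ["fc_0.w_0"]})

# Before the fix: a str-vs-list comparison, which can never match.
assert ("fc_0.w_0" == opt_op.input("Param")) is False
# After the fix: index [0] extracts the single variable name first.
assert opt_op.input("Param")[0] == "fc_0.w_0"
```

The third hunk replaces the try/except around `create_var` with an explicit existence check, so the variable is created only if the pserver block does not already have it. The sketch below illustrates that pattern with a plain dict standing in for `pserver_program.global_block().vars`; `create_var_once` is a hypothetical helper, and `name not in block_vars` is the portable spelling of the Python-2 `has_key()` used in the diff.

```python
# Sketch of the check-before-create pattern adopted by the diff:
# only create the variable when the block lacks it, rather than
# attempting the create and swallowing ValueError (which would also
# hide unrelated errors raised during creation).
block_vars = {}  # stands in for pserver_program.global_block().vars

def create_var_once(name, **attrs):
    if name not in block_vars:  # equivalent to not block_vars.has_key(name)
        block_vars[name] = dict(attrs)

create_var_once("fc_0.w_0", persistable=True, shape=[784, 10])
create_var_once("fc_0.w_0", persistable=True, shape=[784, 10])  # no-op
assert len(block_vars) == 1
```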