diff --git a/python/paddle/fluid/distribute_transpiler.py b/python/paddle/fluid/distribute_transpiler.py
index 3c6be913200716ae4f70e2b48ee8faf8078223d2..0ec3ebc7e3dba6e4cf89c8a76622761d210276cf 100644
--- a/python/paddle/fluid/distribute_transpiler.py
+++ b/python/paddle/fluid/distribute_transpiler.py
@@ -278,11 +278,21 @@ class DistributeTranspiler:
             # we don't need to create them when grad arrives.
             # change client side var name to origin name by
             # removing ".trainer_%d" suffix
+            suff_idx = v.name.find(".trainer_")
             if suff_idx >= 0:
                 orig_var_name = v.name[:suff_idx]
             else:
                 orig_var_name = v.name
+            # NOTE: single_trainer_var must be created for multi-trainer
+            # case to merge grads from multiple trainers
+            single_trainer_var = \
+                pserver_program.global_block().create_var(
+                    name=orig_var_name,
+                    persistable=True,
+                    type=v.type,
+                    dtype=v.dtype,
+                    shape=v.shape)
             if self.trainers > 1:
                 for trainer_id in xrange(self.trainers):
                     var = pserver_program.global_block().create_var(
@@ -293,12 +303,6 @@ class DistributeTranspiler:
                         shape=v.shape)
                     recv_inputs.append(var)
             else:
-                single_trainer_var = pserver_program.global_block().create_var(
-                    name=orig_var_name,
-                    persistable=True,
-                    type=v.type,
-                    dtype=v.dtype,
-                    shape=v.shape)
                 recv_inputs.append(single_trainer_var)
 
         # step3
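
The effect of this patch is that the parameter server now always creates a variable under the original gradient name (single_trainer_var), so that when more than one trainer is running, the gradients received as "<name>.trainer_0", "<name>.trainer_1", ... have a pre-created destination to be merged into. The sketch below is a simplified, framework-free illustration of the resulting naming logic under that reading; it is not Paddle API, and the helper name pserver_recv_vars, the parameter grad_names, and the sample gradient name fc_0.w_0@GRAD are made up for the example.

# Simplified sketch (not Paddle API) of the pserver-side receive-variable
# setup after this patch; names here are illustrative only.


def pserver_recv_vars(grad_names, trainers):
    """Return (created_vars, recv_inputs) for the given gradient var names."""
    created, recv_inputs = [], []
    for name in grad_names:
        # recover the original gradient name by stripping the ".trainer_%d" suffix
        suff_idx = name.find(".trainer_")
        orig_var_name = name[:suff_idx] if suff_idx >= 0 else name

        # after this patch, the variable with the original name is created
        # unconditionally, so multi-trainer gradients can be summed into it
        created.append(orig_var_name)

        if trainers > 1:
            # one receive slot per trainer, later merged into the variable above
            per_trainer = [
                "%s.trainer_%d" % (orig_var_name, i) for i in range(trainers)
            ]
            created.extend(per_trainer)
            recv_inputs.extend(per_trainer)
        else:
            # single trainer: receive directly into the original-name variable
            recv_inputs.append(orig_var_name)
    return created, recv_inputs


if __name__ == "__main__":
    created, recv = pserver_recv_vars(["fc_0.w_0@GRAD.trainer_0"], trainers=2)
    print(created)  # ['fc_0.w_0@GRAD', 'fc_0.w_0@GRAD.trainer_0', 'fc_0.w_0@GRAD.trainer_1']
    print(recv)     # ['fc_0.w_0@GRAD.trainer_0', 'fc_0.w_0@GRAD.trainer_1']

Before this change, the variable with the original name was only created in the single-trainer branch (the deleted block in the second hunk), so in the multi-trainer case the merged gradient had no pre-created variable on the pserver.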