diff --git a/python/paddle/fluid/distribute_transpiler.py b/python/paddle/fluid/distribute_transpiler.py
index 3c6be913200716ae4f70e2b48ee8faf8078223d2..c5b0ffb854726721d945c90cded3c7121a3a6840 100644
--- a/python/paddle/fluid/distribute_transpiler.py
+++ b/python/paddle/fluid/distribute_transpiler.py
@@ -278,8 +278,17 @@ class DistributeTranspiler:
             # we don't need to create them when grad arrives.
             # change client side var name to origin name by
             # removing ".trainer_%d" suffix
             suff_idx = v.name.find(".trainer_")
             if suff_idx >= 0:
                 orig_var_name = v.name[:suff_idx]
             else:
                 orig_var_name = v.name
+            # NOTE: single_trainer_var must be created for the multi-trainer
+            # case as well, to merge grads from multiple trainers
+            single_trainer_var = \
+                pserver_program.global_block().create_var(
+                    name=orig_var_name,
+                    persistable=True,
+                    type=v.type,
+                    dtype=v.dtype,
+                    shape=v.shape)
@@ -293,12 +302,6 @@ class DistributeTranspiler:
                     shape=v.shape)
                 recv_inputs.append(var)
             else:
-                single_trainer_var = pserver_program.global_block().create_var(
-                    name=orig_var_name,
-                    persistable=True,
-                    type=v.type,
-                    dtype=v.dtype,
-                    shape=v.shape)
                 recv_inputs.append(single_trainer_var)
 
         # step3
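
For context on the fix: in the parameter-server topology the transpiler builds, each trainer sends its gradient under a per-trainer alias ("w@GRAD.trainer_0", "w@GRAD.trainer_1", ...), and the pserver merges those into one persistable var carrying the origin name. Before this patch, that merged var was only created in the single-trainer else branch, so the multi-trainer path had nothing to merge into. The sketch below is a minimal, framework-free illustration of that naming scheme, not the transpiler's actual code: origin_name and build_recv_inputs are invented helper names, and a plain dict stands in for the pserver block's variables; only the ".trainer_%d" suffix convention and the branching on the trainer count come from the diff.

    def origin_name(var_name):
        # Strip the ".trainer_%d" suffix, as the transpiler does, so grads
        # sent by different trainers map back to one origin gradient name.
        suff_idx = var_name.find(".trainer_")
        return var_name[:suff_idx] if suff_idx >= 0 else var_name


    def build_recv_inputs(grad_names, trainers):
        # Mirror the patched control flow: the merged var (keyed by origin
        # name) is always created; per-trainer recv vars exist only when
        # trainers > 1, and the single-trainer case receives directly into
        # the merged var.
        merged_vars = {origin_name(n): 0.0 for n in grad_names}
        recv_inputs = []
        for name in merged_vars:
            if trainers > 1:
                recv_inputs.extend(
                    "%s.trainer_%d" % (name, i) for i in range(trainers))
            else:
                recv_inputs.append(name)
        return merged_vars, recv_inputs


    merged, recv = build_recv_inputs(
        ["w@GRAD.trainer_0", "w@GRAD.trainer_1"], trainers=2)
    print(recv)          # ['w@GRAD.trainer_0', 'w@GRAD.trainer_1']
    print(list(merged))  # ['w@GRAD'] -- the var the per-trainer grads merge into

The point the sketch makes is the same one the NOTE in the diff makes: the merged var must exist unconditionally, because in the multi-trainer branch only the suffixed per-trainer vars are appended to recv_inputs, and the summed result still needs a destination named after the origin var.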