From d02b17e597cc83988b35f09b929a8f7701607310 Mon Sep 17 00:00:00 2001
From: typhoonzero
Date: Mon, 9 Apr 2018 20:50:44 +0800
Subject: [PATCH] fix dist transpiler bug

---
 python/paddle/fluid/distribute_transpiler.py | 15 +++++++++------
 1 file changed, 9 insertions(+), 6 deletions(-)

diff --git a/python/paddle/fluid/distribute_transpiler.py b/python/paddle/fluid/distribute_transpiler.py
index 3c6be913200..c5b0ffb8547 100644
--- a/python/paddle/fluid/distribute_transpiler.py
+++ b/python/paddle/fluid/distribute_transpiler.py
@@ -278,6 +278,15 @@ class DistributeTranspiler:
             # we don't need to create them when grad arrives.
             # change client side var name to origin name by
             # removing ".trainer_%d" suffix
+            # NOTE: single_trainer_var must be created for multi-trainer
+            # case to merge grads from multiple trainers
+            single_trainer_var = \
+                pserver_program.global_block().create_var(
+                    name=orig_var_name,
+                    persistable=True,
+                    type=v.type,
+                    dtype=v.dtype,
+                    shape=v.shape)
             suff_idx = v.name.find(".trainer_")
             if suff_idx >= 0:
                 orig_var_name = v.name[:suff_idx]
@@ -293,12 +302,6 @@ class DistributeTranspiler:
                         shape=v.shape)
                 recv_inputs.append(var)
             else:
-                single_trainer_var = pserver_program.global_block().create_var(
-                    name=orig_var_name,
-                    persistable=True,
-                    type=v.type,
-                    dtype=v.dtype,
-                    shape=v.shape)
                 recv_inputs.append(single_trainer_var)

         # step3
--
GitLab