From 92313a99c046adc7caf896f339804efd89006b37 Mon Sep 17 00:00:00 2001 From: typhoonzero Date: Mon, 9 Apr 2018 20:58:49 +0800 Subject: [PATCH] update --- python/paddle/fluid/distribute_transpiler.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/python/paddle/fluid/distribute_transpiler.py b/python/paddle/fluid/distribute_transpiler.py index c5b0ffb854..0ec3ebc7e3 100644 --- a/python/paddle/fluid/distribute_transpiler.py +++ b/python/paddle/fluid/distribute_transpiler.py @@ -278,6 +278,12 @@ class DistributeTranspiler: # we don't need to create them when grad arrives. # change client side var name to origin name by # removing ".trainer_%d" suffix + + suff_idx = v.name.find(".trainer_") + if suff_idx >= 0: + orig_var_name = v.name[:suff_idx] + else: + orig_var_name = v.name # NOTE: single_trainer_var must be created for multi-trainer # case to merge grads from multiple trainers single_trainer_var = \ @@ -287,11 +293,6 @@ class DistributeTranspiler: type=v.type, dtype=v.dtype, shape=v.shape) - suff_idx = v.name.find(".trainer_") - if suff_idx >= 0: - orig_var_name = v.name[:suff_idx] - else: - orig_var_name = v.name if self.trainers > 1: for trainer_id in xrange(self.trainers): var = pserver_program.global_block().create_var( -- GitLab