From 9d3114c8c99b46a90b76811b3a2b19b1c1e31919 Mon Sep 17 00:00:00 2001 From: "yi.wu" Date: Mon, 4 Jun 2018 18:42:44 +0800 Subject: [PATCH] fix transpiler error --- .../paddle/fluid/transpiler/distribute_transpiler.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/python/paddle/fluid/transpiler/distribute_transpiler.py b/python/paddle/fluid/transpiler/distribute_transpiler.py index da001add8e1..27992df462f 100644 --- a/python/paddle/fluid/transpiler/distribute_transpiler.py +++ b/python/paddle/fluid/transpiler/distribute_transpiler.py @@ -187,12 +187,17 @@ class DistributeTranspiler: param_list = [] grad_list = [] + param_grad_set = set() for p, g in self.params_grads: # skip parameter marked not trainable if type(p) == Parameter and p.trainable == False: continue - param_list.append(p) - grad_list.append(g) + if p.name not in param_grad_set: + param_list.append(p) + param_grad_set.add(p.name) + if g.name not in param_grad_set: + grad_list.append(g) + param_grad_set.add(g.name) self._update_dist_lookup_table_vars(param_list, grad_list, self.params_grads) @@ -829,6 +834,9 @@ class DistributeTranspiler: if not block_map.has_key(varname): block_map[varname] = [] block_map[varname].append((long(offset), long(size))) + # Do not remove this important debug message: + print("block map: %s" % block_map) + for varname, splited in block_map.iteritems(): orig_var = program.global_block().var(varname) if len(splited) == 1: -- GitLab