From 5e853840c4d0f59f7a8172fe4d21d7fa6431e984 Mon Sep 17 00:00:00 2001 From: "yi.wu" Date: Mon, 4 Jun 2018 18:42:44 +0800 Subject: [PATCH] fix transpiler error --- .../paddle/fluid/transpiler/distribute_transpiler.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/python/paddle/fluid/transpiler/distribute_transpiler.py b/python/paddle/fluid/transpiler/distribute_transpiler.py index 06b0a1375..819073257 100644 --- a/python/paddle/fluid/transpiler/distribute_transpiler.py +++ b/python/paddle/fluid/transpiler/distribute_transpiler.py @@ -186,12 +186,17 @@ class DistributeTranspiler: param_list = [] grad_list = [] + param_grad_set = set() for p, g in self.params_grads: # skip parameter marked not trainable if type(p) == Parameter and p.trainable == False: continue - param_list.append(p) - grad_list.append(g) + if p.name not in param_grad_set: + param_list.append(p) + param_grad_set.add(p.name) + if g.name not in param_grad_set: + grad_list.append(g) + param_grad_set.add(g.name) self._update_dist_lookup_table_vars(param_list, grad_list, self.params_grads) @@ -802,6 +807,9 @@ class DistributeTranspiler: if not block_map.has_key(varname): block_map[varname] = [] block_map[varname].append((long(offset), long(size))) + # Do not remove this important debug message: + print("block map: %s" % block_map) + for varname, splited in block_map.iteritems(): orig_var = program.global_block().var(varname) if len(splited) == 1: -- GitLab