未验证 提交 9f0dcfd4 编写于 作者: T tangwei12 提交者: GitHub

Merge pull request #11155 from typhoonzero/fix_transpiler_merged_bug

Fix single pserver transpile error after merging
...@@ -187,12 +187,17 @@ class DistributeTranspiler: ...@@ -187,12 +187,17 @@ class DistributeTranspiler:
param_list = [] param_list = []
grad_list = [] grad_list = []
param_grad_set = set()
for p, g in self.params_grads: for p, g in self.params_grads:
# skip parameter marked not trainable # skip parameter marked not trainable
if type(p) == Parameter and p.trainable == False: if type(p) == Parameter and p.trainable == False:
continue continue
param_list.append(p) if p.name not in param_grad_set:
grad_list.append(g) param_list.append(p)
param_grad_set.add(p.name)
if g.name not in param_grad_set:
grad_list.append(g)
param_grad_set.add(g.name)
self._update_dist_lookup_table_vars(param_list, grad_list, self._update_dist_lookup_table_vars(param_list, grad_list,
self.params_grads) self.params_grads)
...@@ -829,6 +834,9 @@ class DistributeTranspiler: ...@@ -829,6 +834,9 @@ class DistributeTranspiler:
if not block_map.has_key(varname): if not block_map.has_key(varname):
block_map[varname] = [] block_map[varname] = []
block_map[varname].append((long(offset), long(size))) block_map[varname].append((long(offset), long(size)))
# Do not remove this important debug message:
print("block map: %s" % block_map)
for varname, splited in block_map.iteritems(): for varname, splited in block_map.iteritems():
orig_var = program.global_block().var(varname) orig_var = program.global_block().var(varname)
if len(splited) == 1: if len(splited) == 1:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册