提交 5e853840 编写于 作者: Y yi.wu 提交者: typhoonzero

fix transpiler error

上级 997e8b9b
...@@ -186,12 +186,17 @@ class DistributeTranspiler: ...@@ -186,12 +186,17 @@ class DistributeTranspiler:
param_list = [] param_list = []
grad_list = [] grad_list = []
param_grad_set = set()
for p, g in self.params_grads: for p, g in self.params_grads:
# skip parameter marked not trainable # skip parameter marked not trainable
if type(p) == Parameter and p.trainable == False: if type(p) == Parameter and p.trainable == False:
continue continue
param_list.append(p) if p.name not in param_grad_set:
grad_list.append(g) param_list.append(p)
param_grad_set.add(p.name)
if g.name not in param_grad_set:
grad_list.append(g)
param_grad_set.add(g.name)
self._update_dist_lookup_table_vars(param_list, grad_list, self._update_dist_lookup_table_vars(param_list, grad_list,
self.params_grads) self.params_grads)
...@@ -802,6 +807,9 @@ class DistributeTranspiler: ...@@ -802,6 +807,9 @@ class DistributeTranspiler:
if not block_map.has_key(varname): if not block_map.has_key(varname):
block_map[varname] = [] block_map[varname] = []
block_map[varname].append((long(offset), long(size))) block_map[varname].append((long(offset), long(size)))
# Do not remove this important debug message:
print("block map: %s" % block_map)
for varname, splited in block_map.iteritems(): for varname, splited in block_map.iteritems():
orig_var = program.global_block().var(varname) orig_var = program.global_block().var(varname)
if len(splited) == 1: if len(splited) == 1:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册