提交 bf03a209 编写于 作者: Q qiaolongfei

fix distribute_transpiler

上级 637827a5
......@@ -177,6 +177,7 @@ class DistributeTranspiler:
dtype=table_grad_var.dtype)
for index in range(len(self.pserver_endpoints))
]
return param_list, grad_list
def _init_splited_vars(self, slice_var_up):
# update these mappings for further transpile:
......@@ -199,8 +200,8 @@ class DistributeTranspiler:
grad_list.append(g)
param_grad_set.add(g.name)
self._update_dist_lookup_table_vars(param_list, grad_list,
self.params_grads)
param_list, grad_list = self._update_dist_lookup_table_vars(
param_list, grad_list, self.params_grads)
if slice_var_up:
# when we slice var up into blocks, we will slice the var according to
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册