提交 bf03a209 编写于 作者: Q qiaolongfei

fix distribute_transpiler

上级 637827a5
...@@ -177,6 +177,7 @@ class DistributeTranspiler: ...@@ -177,6 +177,7 @@ class DistributeTranspiler:
dtype=table_grad_var.dtype) dtype=table_grad_var.dtype)
for index in range(len(self.pserver_endpoints)) for index in range(len(self.pserver_endpoints))
] ]
return param_list, grad_list
def _init_splited_vars(self, slice_var_up): def _init_splited_vars(self, slice_var_up):
# update these mappings for further transpile: # update these mappings for further transpile:
...@@ -199,8 +200,8 @@ class DistributeTranspiler: ...@@ -199,8 +200,8 @@ class DistributeTranspiler:
grad_list.append(g) grad_list.append(g)
param_grad_set.add(g.name) param_grad_set.add(g.name)
self._update_dist_lookup_table_vars(param_list, grad_list, param_list, grad_list = self._update_dist_lookup_table_vars(
self.params_grads) param_list, grad_list, self.params_grads)
if slice_var_up: if slice_var_up:
# when we slice var up into blocks, we will slice the var according to # when we slice var up into blocks, we will slice the var according to
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册