diff --git a/python/paddle/fluid/transpiler/distribute_transpiler.py b/python/paddle/fluid/transpiler/distribute_transpiler.py index ed4158bc4c91b4db0143c8d607a5ea9220e528ff..bf06f4c2ca73c8cfd79f39eb37499f07d854543c 100644 --- a/python/paddle/fluid/transpiler/distribute_transpiler.py +++ b/python/paddle/fluid/transpiler/distribute_transpiler.py @@ -177,6 +177,7 @@ class DistributeTranspiler: dtype=table_grad_var.dtype) for index in range(len(self.pserver_endpoints)) ] + return param_list, grad_list def _init_splited_vars(self, slice_var_up): # update these mappings for further transpile: @@ -199,8 +200,8 @@ class DistributeTranspiler: grad_list.append(g) param_grad_set.add(g.name) - self._update_dist_lookup_table_vars(param_list, grad_list, - self.params_grads) + param_list, grad_list = self._update_dist_lookup_table_vars( + param_list, grad_list, self.params_grads) if slice_var_up: # when we slice var up into blocks, we will slice the var according to