From bf03a2094bce7c542dd64c3a29f445e04c68640b Mon Sep 17 00:00:00 2001 From: qiaolongfei Date: Sun, 10 Jun 2018 13:24:38 +0800 Subject: [PATCH] fix distribute_transpiler --- python/paddle/fluid/transpiler/distribute_transpiler.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/python/paddle/fluid/transpiler/distribute_transpiler.py b/python/paddle/fluid/transpiler/distribute_transpiler.py index 27992df462..c7ab300e0f 100644 --- a/python/paddle/fluid/transpiler/distribute_transpiler.py +++ b/python/paddle/fluid/transpiler/distribute_transpiler.py @@ -177,6 +177,7 @@ class DistributeTranspiler: dtype=table_grad_var.dtype) for index in range(len(self.pserver_endpoints)) ] + return param_list, grad_list def _init_splited_vars(self, slice_var_up): # update these mappings for further transpile: @@ -199,8 +200,8 @@ class DistributeTranspiler: grad_list.append(g) param_grad_set.add(g.name) - self._update_dist_lookup_table_vars(param_list, grad_list, - self.params_grads) + param_list, grad_list = self._update_dist_lookup_table_vars( + param_list, grad_list, self.params_grads) if slice_var_up: # when we slice var up into blocks, we will slice the var according to -- GitLab