diff --git a/python/paddle/fluid/distribute_transpiler.py b/python/paddle/fluid/distribute_transpiler.py index 246d6cb6c86bb6e8ba2c985296aa6a375e9e79f8..3aa89cb0c40111ae218b0e34954e3b1d9eb5ffa8 100644 --- a/python/paddle/fluid/distribute_transpiler.py +++ b/python/paddle/fluid/distribute_transpiler.py @@ -453,8 +453,7 @@ class DistributeTranspiler: if self.has_distributed_lookup_table: pserver_index = self.pserver_endpoints.index(endpoint) table_opt_block = self._create_table_optimize_block( - pserver_index, pserver_program, opt_state_block or - pserver_program.global_block()) + pserver_index, pserver_program, pre_block_idx) prefetch_block = self._create_prefetch_block( pserver_index, pserver_program, table_opt_block) @@ -665,7 +664,7 @@ class DistributeTranspiler: return prefetch_block def _create_table_optimize_block(self, pserver_index, pserver_program, - append_block): + pre_block_idx): def _clone_var(block, var, persistable=True): assert isinstance(var, Variable) return block.create_var( @@ -702,7 +701,7 @@ class DistributeTranspiler: op for op in self.optimize_ops if op.input("Param")[0] == self.table_name ][0] - table_opt_block = pserver_program.create_block(append_block.idx) + table_opt_block = pserver_program.create_block(pre_block_idx) # only support sgd now assert table_opt_op.type == "sgd"