提交 18e0b730 编写于 作者: Q qiaolongfei

fix _create_table_optimize_block

上级 39f6274e
...@@ -453,8 +453,7 @@ class DistributeTranspiler: ...@@ -453,8 +453,7 @@ class DistributeTranspiler:
if self.has_distributed_lookup_table: if self.has_distributed_lookup_table:
pserver_index = self.pserver_endpoints.index(endpoint) pserver_index = self.pserver_endpoints.index(endpoint)
table_opt_block = self._create_table_optimize_block( table_opt_block = self._create_table_optimize_block(
pserver_index, pserver_program, opt_state_block or pserver_index, pserver_program, pre_block_idx)
pserver_program.global_block())
prefetch_block = self._create_prefetch_block( prefetch_block = self._create_prefetch_block(
pserver_index, pserver_program, table_opt_block) pserver_index, pserver_program, table_opt_block)
...@@ -665,7 +664,7 @@ class DistributeTranspiler: ...@@ -665,7 +664,7 @@ class DistributeTranspiler:
return prefetch_block return prefetch_block
def _create_table_optimize_block(self, pserver_index, pserver_program, def _create_table_optimize_block(self, pserver_index, pserver_program,
append_block): pre_block_idx):
def _clone_var(block, var, persistable=True): def _clone_var(block, var, persistable=True):
assert isinstance(var, Variable) assert isinstance(var, Variable)
return block.create_var( return block.create_var(
...@@ -702,7 +701,7 @@ class DistributeTranspiler: ...@@ -702,7 +701,7 @@ class DistributeTranspiler:
op for op in self.optimize_ops op for op in self.optimize_ops
if op.input("Param")[0] == self.table_name if op.input("Param")[0] == self.table_name
][0] ][0]
table_opt_block = pserver_program.create_block(append_block.idx) table_opt_block = pserver_program.create_block(pre_block_idx)
# only support sgd now # only support sgd now
assert table_opt_op.type == "sgd" assert table_opt_op.type == "sgd"
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册