提交 16027ea1 编写于 作者: Q qiaolongfei

use block.clone_variable instead of _clone_var

上级 85d0301a
...@@ -689,15 +689,6 @@ class DistributeTranspiler: ...@@ -689,15 +689,6 @@ class DistributeTranspiler:
def _create_table_optimize_block(self, pserver_index, pserver_program, def _create_table_optimize_block(self, pserver_index, pserver_program,
pre_block_idx, grad_to_block_id): pre_block_idx, grad_to_block_id):
def _clone_var(block, var, persistable=True):
assert isinstance(var, Variable)
return block.create_var(
name=var.name,
shape=var.shape,
dtype=var.dtype,
type=var.type,
persistable=persistable)
# STEP: create table optimize block # STEP: create table optimize block
# create table param and grad var in pserver program # create table param and grad var in pserver program
origin_param_var = self.origin_program.global_block().vars[ origin_param_var = self.origin_program.global_block().vars[
...@@ -708,11 +699,11 @@ class DistributeTranspiler: ...@@ -708,11 +699,11 @@ class DistributeTranspiler:
dtype=origin_param_var.dtype, dtype=origin_param_var.dtype,
type=core.VarDesc.VarType.SELECTED_ROWS, type=core.VarDesc.VarType.SELECTED_ROWS,
persistable=True) persistable=True)
grad_var = _clone_var( # parameter must be selected rows
pserver_program.global_block(), param_var.desc.set_type(core.VarDesc.VarType.SELECTED_ROWS)
grad_var = pserver_program.global_block().clone_variable(
self.origin_program.global_block().vars[grad_var_name( self.origin_program.global_block().vars[grad_var_name(
self.table_name)], self.table_name)])
persistable=False)
# create table optimize block in pserver program # create table optimize block in pserver program
table_opt_op = [ table_opt_op = [
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册