diff --git a/python/paddle/fluid/distribute_transpiler.py b/python/paddle/fluid/distribute_transpiler.py index 6d76c1a8d13abb69223b9b335a270c36007410b0..134dbe573a50110f0eeec61816c47f6bcfa82161 100644 --- a/python/paddle/fluid/distribute_transpiler.py +++ b/python/paddle/fluid/distribute_transpiler.py @@ -412,12 +412,7 @@ class DistributeTranspiler: tmpvar = s_prog.global_block().create_var( name=var.name, persistable=var.persistable, type=var.type) else: - tmpvar = s_prog.global_block().create_var( - name=var.name, - persistable=var.persistable, - type=var.type, - dtype=var.dtype, - shape=var.shape) + tmpvar = s_prog.global_block().clone_variable(var) created_var_map[var.name] = tmpvar # 2. rename op outputs @@ -713,18 +708,13 @@ class DistributeTranspiler: varlist = [varlist] for var in varlist: - print("##### deal var: ", var) if var.type == core.VarDesc.VarType.STEP_SCOPES: program.global_block().create_var( name=var.name, persistable=var.persistable, type=var.type) else: - program.global_block().create_var( - name=var.name, - persistable=var.persistable, - dtype=var.dtype, - shape=var.shape) + program.global_block().clone_variable(var) optimize_block.append_op( type=opt_op.type,