diff --git a/python/paddle/fluid/distribute_transpiler.py b/python/paddle/fluid/distribute_transpiler.py index 9311fc9904eb730aa56e94a4e45a1479a67df641..6d76c1a8d13abb69223b9b335a270c36007410b0 100644 --- a/python/paddle/fluid/distribute_transpiler.py +++ b/python/paddle/fluid/distribute_transpiler.py @@ -408,11 +408,16 @@ class DistributeTranspiler: pserver_vars = pserver_program.global_block().vars created_var_map = dict() for _, var in pserver_vars.iteritems(): - tmpvar = s_prog.global_block().create_var( - name=var.name, - persistable=var.persistable, - dtype=var.dtype, - shape=var.shape) + if var.type == core.VarDesc.VarType.STEP_SCOPES: + tmpvar = s_prog.global_block().create_var( + name=var.name, persistable=var.persistable, type=var.type) + else: + tmpvar = s_prog.global_block().create_var( + name=var.name, + persistable=var.persistable, + type=var.type, + dtype=var.dtype, + shape=var.shape) created_var_map[var.name] = tmpvar # 2. rename op outputs @@ -708,11 +713,18 @@ class DistributeTranspiler: varlist = [varlist] for var in varlist: - program.global_block().create_var( - name=var.name, - persistable=var.persistable, - dtype=var.dtype, - shape=var.shape) + print("##### deal var: ", var) + if var.type == core.VarDesc.VarType.STEP_SCOPES: + program.global_block().create_var( + name=var.name, + persistable=var.persistable, + type=var.type) + else: + program.global_block().create_var( + name=var.name, + persistable=var.persistable, + dtype=var.dtype, + shape=var.shape) optimize_block.append_op( type=opt_op.type,