From a16a872783d52d9ba7d32d53848e95cc4ccaefd6 Mon Sep 17 00:00:00 2001 From: typhoonzero Date: Wed, 4 Apr 2018 15:56:21 +0800 Subject: [PATCH] update --- python/paddle/fluid/distribute_transpiler.py | 14 ++----------- python/paddle/fluid/framework.py | 21 +++++++++++++------- 2 files changed, 16 insertions(+), 19 deletions(-) diff --git a/python/paddle/fluid/distribute_transpiler.py b/python/paddle/fluid/distribute_transpiler.py index 134dbe573a5..31bedb592f1 100644 --- a/python/paddle/fluid/distribute_transpiler.py +++ b/python/paddle/fluid/distribute_transpiler.py @@ -408,11 +408,7 @@ class DistributeTranspiler: pserver_vars = pserver_program.global_block().vars created_var_map = dict() for _, var in pserver_vars.iteritems(): - if var.type == core.VarDesc.VarType.STEP_SCOPES: - tmpvar = s_prog.global_block().create_var( - name=var.name, persistable=var.persistable, type=var.type) - else: - tmpvar = s_prog.global_block().clone_variable(var) + tmpvar = s_prog.global_block().clone_variable(var) created_var_map[var.name] = tmpvar # 2. rename op outputs @@ -708,13 +704,7 @@ class DistributeTranspiler: varlist = [varlist] for var in varlist: - if var.type == core.VarDesc.VarType.STEP_SCOPES: - program.global_block().create_var( - name=var.name, - persistable=var.persistable, - type=var.type) - else: - program.global_block().clone_variable(var) + program.global_block().clone_variable(var) optimize_block.append_op( type=opt_op.type, diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py index e15456bfc08..39d4017861f 100644 --- a/python/paddle/fluid/framework.py +++ b/python/paddle/fluid/framework.py @@ -946,13 +946,20 @@ class Block(object): The new variable cloned from 'var' in current block. """ assert isinstance(var, Variable) - return self.create_var( - name=var.name, - shape=var.shape, - dtype=var.dtype, - type=var.type, - lod_level=var.lod_level, - persistable=True) + ret_var = None + # make STEP_SCOPES var can be safely cloned. + if var.type == core.VarDesc.VarType.STEP_SCOPES: + ret_var = self.create_var( + name=var.name, persistable=var.persistable, type=var.type) + else: + ret_var = self.create_var( + name=var.name, + shape=var.shape, + dtype=var.dtype, + type=var.type, + lod_level=var.lod_level, + persistable=True) + return ret_var class Program(object): -- GitLab