From 39277e9282294dc18b4c2b93aa000a15b58bea5f Mon Sep 17 00:00:00 2001 From: typhoonzero Date: Wed, 4 Apr 2018 14:55:28 +0800 Subject: [PATCH] fix transpiler condition op in optimize --- python/paddle/fluid/distribute_transpiler.py | 32 ++++++++++++++------ 1 file changed, 22 insertions(+), 10 deletions(-) diff --git a/python/paddle/fluid/distribute_transpiler.py b/python/paddle/fluid/distribute_transpiler.py index 9311fc9904e..6d76c1a8d13 100644 --- a/python/paddle/fluid/distribute_transpiler.py +++ b/python/paddle/fluid/distribute_transpiler.py @@ -408,11 +408,16 @@ class DistributeTranspiler: pserver_vars = pserver_program.global_block().vars created_var_map = dict() for _, var in pserver_vars.iteritems(): - tmpvar = s_prog.global_block().create_var( - name=var.name, - persistable=var.persistable, - dtype=var.dtype, - shape=var.shape) + if var.type == core.VarDesc.VarType.STEP_SCOPES: + tmpvar = s_prog.global_block().create_var( + name=var.name, persistable=var.persistable, type=var.type) + else: + tmpvar = s_prog.global_block().create_var( + name=var.name, + persistable=var.persistable, + type=var.type, + dtype=var.dtype, + shape=var.shape) created_var_map[var.name] = tmpvar # 2. rename op outputs @@ -708,11 +713,18 @@ class DistributeTranspiler: varlist = [varlist] for var in varlist: - program.global_block().create_var( - name=var.name, - persistable=var.persistable, - dtype=var.dtype, - shape=var.shape) + print("##### deal var: ", var) + if var.type == core.VarDesc.VarType.STEP_SCOPES: + program.global_block().create_var( + name=var.name, + persistable=var.persistable, + type=var.type) + else: + program.global_block().create_var( + name=var.name, + persistable=var.persistable, + dtype=var.dtype, + shape=var.shape) optimize_block.append_op( type=opt_op.type, -- GitLab