From 5d901d00bf9c93225548d707b1c3b79634b801b4 Mon Sep 17 00:00:00 2001 From: "yi.wu" Date: Thu, 11 Jan 2018 22:41:24 +0800 Subject: [PATCH] update --- .../paddle/v2/fluid/distribute_transpiler.py | 33 +++++++++++-------- 1 file changed, 20 insertions(+), 13 deletions(-) diff --git a/python/paddle/v2/fluid/distribute_transpiler.py b/python/paddle/v2/fluid/distribute_transpiler.py index 75e103cb80..59e74e0d6f 100644 --- a/python/paddle/v2/fluid/distribute_transpiler.py +++ b/python/paddle/v2/fluid/distribute_transpiler.py @@ -459,9 +459,10 @@ class DistributeTranspiler: return pname, splited_param.shape return "", [] - # 1. create vars + # 1. create vars in pserver program to startup program + pserver_vars = pserver_program.global_block().vars created_var_map = dict() - for _, var in pserver_program.global_block().vars.iteritems(): + for _, var in pserver_vars.iteritems(): print("create var for startup", var.name, var.shape) tmpvar = s_prog.global_block().create_var( name=var.name, @@ -469,30 +470,36 @@ class DistributeTranspiler: dtype=var.dtype, shape=var.shape) created_var_map[var.name] = tmpvar - optimize_op_input_var_names = [ - v.name for v in pserver_program.global_block().vars.values() - ] # 2. rename op outputs for op in orig_s_prog.global_block().ops: new_outputs = dict() + # do not append startup op if var is not on this pserver + op_on_pserver = False for key, var in op.outputs.iteritems(): newname, _ = _get_splited_name_and_shape(var.name) if newname: + op_on_pserver = True new_outputs[key] = created_var_map[newname] - else: - new_outputs[key] = var - # do not append startup op if var is not on this pserver - op_on_pserver = False - for _, var in op.outputs.iteritems(): - if var.name in optimize_op_input_var_names: + elif var.name in pserver_vars: op_on_pserver = True - break + new_outputs[key] = pserver_vars[var.name] + + # newname, _ = _get_splited_name_and_shape(var.name) + # if newname: + # print("updating output", newname, created_var_map[newname]) + # new_outputs[key] = created_var_map[newname] + # else: + # print("no update output", key, var) + # new_outputs[key] = var + # if var.name in created_var_map or \ + # newname: + # op_on_pserver = True if op_on_pserver: - # gaussian_random use attr to determine tensor shape if op.type in ["gaussian_random", "fill_constant"]: op.attrs["shape"] = new_outputs["Out"].shape + print("updated shape", op.attrs["shape"]) s_prog.global_block().append_op( type=op.type, inputs=op.inputs, -- GitLab