diff --git a/python/paddle/v2/fluid/distribute_transpiler.py b/python/paddle/v2/fluid/distribute_transpiler.py
index 75e103cb80c39d0378582b7d4c0bd82cfcc6fa75..59e74e0d6f206910ade8b8b7617428390fbf23bd 100644
--- a/python/paddle/v2/fluid/distribute_transpiler.py
+++ b/python/paddle/v2/fluid/distribute_transpiler.py
@@ -459,9 +459,10 @@ class DistributeTranspiler:
                     return pname, splited_param.shape
             return "", []
 
-        # 1. create vars
+        # 1. create vars in the startup program for all vars in the pserver program
+        pserver_vars = pserver_program.global_block().vars
         created_var_map = dict()
-        for _, var in pserver_program.global_block().vars.iteritems():
+        for _, var in pserver_vars.iteritems():
             print("create var for startup", var.name, var.shape)
             tmpvar = s_prog.global_block().create_var(
                 name=var.name,
@@ -469,30 +470,27 @@ class DistributeTranspiler:
                 dtype=var.dtype,
                 shape=var.shape)
             created_var_map[var.name] = tmpvar
-        optimize_op_input_var_names = [
-            v.name for v in pserver_program.global_block().vars.values()
-        ]
 
         # 2. rename op outputs
         for op in orig_s_prog.global_block().ops:
             new_outputs = dict()
+            # do not append startup op if var is not on this pserver
+            op_on_pserver = False
             for key, var in op.outputs.iteritems():
                 newname, _ = _get_splited_name_and_shape(var.name)
                 if newname:
+                    op_on_pserver = True
                     new_outputs[key] = created_var_map[newname]
-                else:
-                    new_outputs[key] = var
-            # do not append startup op if var is not on this pserver
-            op_on_pserver = False
-            for _, var in op.outputs.iteritems():
-                if var.name in optimize_op_input_var_names:
+                elif var.name in pserver_vars:
+                    # the var is not split, but it still lives on this pserver
                     op_on_pserver = True
-                    break
+                    new_outputs[key] = pserver_vars[var.name]
 
             if op_on_pserver:
-                # gaussian_random use attr to determine tensor shape
+                # gaussian_random and fill_constant use an attr to determine
+                # the tensor shape, so the attr must follow the split shape
                 if op.type in ["gaussian_random", "fill_constant"]:
                     op.attrs["shape"] = new_outputs["Out"].shape
                 s_prog.global_block().append_op(
                     type=op.type,
                     inputs=op.inputs,
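
For reviewers, here is a minimal standalone sketch of the control flow this patch converges on: a startup op is copied to a pserver only if one of its output vars lives on that pserver, either under a split (shard) name or unchanged, and the shape attr of attr-driven initializers is rewritten to the shard shape. This is plain Python with no Fluid dependency; VarStub, OpStub, split_map, and ops_for_pserver are hypothetical stand-ins for illustration, not Fluid APIs.

    # Standalone sketch; all names here are hypothetical, not Fluid APIs.
    from collections import namedtuple

    VarStub = namedtuple("VarStub", ["name", "shape"])
    OpStub = namedtuple("OpStub", ["type", "outputs", "attrs"])

    def ops_for_pserver(startup_ops, pserver_vars, split_map):
        """Keep only the startup ops whose outputs live on this pserver.

        pserver_vars: name -> VarStub for vars created on this pserver.
        split_map: original param name -> shard name on this pserver,
                   mimicking what _get_splited_name_and_shape resolves.
        """
        kept = []
        for op in startup_ops:
            new_outputs = {}
            op_on_pserver = False
            for key, var in op.outputs.items():
                shard_name = split_map.get(var.name)
                if shard_name:
                    # output was split; remap to the shard on this pserver
                    op_on_pserver = True
                    new_outputs[key] = pserver_vars[shard_name]
                elif var.name in pserver_vars:
                    # output was not split but still lives here
                    op_on_pserver = True
                    new_outputs[key] = pserver_vars[var.name]
            if op_on_pserver:
                # these op types read their output shape from an attr, so
                # the attr must be rewritten to the (smaller) shard shape
                if op.type in ("gaussian_random", "fill_constant"):
                    op.attrs["shape"] = new_outputs["Out"].shape
                kept.append((op, new_outputs))
        return kept

    # Example: a parameter split in two; this pserver holds the first shard.
    pserver_vars = {"fc_0.w_0.block0": VarStub("fc_0.w_0.block0", [512, 1000])}
    split_map = {"fc_0.w_0": "fc_0.w_0.block0"}
    init_op = OpStub("gaussian_random",
                     {"Out": VarStub("fc_0.w_0", [1024, 1000])}, {})
    other_op = OpStub("fill_constant",
                      {"Out": VarStub("fc_1.b_0", [1000])}, {})
    kept = ops_for_pserver([init_op, other_op], pserver_vars, split_map)
    assert len(kept) == 1 and kept[0][0].attrs["shape"] == [512, 1000]

Compared with the code being removed, this folds the separate "is any output among the pserver's vars" pass into the single output-remapping loop, which is the simplification the second hunk above makes.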