Commit 5d901d00 authored by Y yi.wu

update: build the pserver startup program from the pserver program's variables instead of a separate optimize-op input list

Parent 5325313e
@@ -459,9 +459,10 @@ class DistributeTranspiler:
                     return pname, splited_param.shape
             return "", []
 
-        # 1. create vars
+        # 1. create vars in pserver program to startup program
+        pserver_vars = pserver_program.global_block().vars
         created_var_map = dict()
-        for _, var in pserver_program.global_block().vars.iteritems():
+        for _, var in pserver_vars.iteritems():
             print("create var for startup", var.name, var.shape)
             tmpvar = s_prog.global_block().create_var(
                 name=var.name,
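Step 1 of the change snapshots `pserver_program.global_block().vars` as `pserver_vars` and creates a matching variable in the new startup program for each entry. Below is a minimal, framework-free sketch of that bookkeeping, assuming plain-dict stand-ins for Paddle's block/variable objects; `FakeVar`, `copy_pserver_vars_to_startup`, and the dict arguments are hypothetical names introduced only for illustration.

```python
# Hypothetical stand-ins for Paddle's Block/Variable objects; the real
# transpiler works on framework.Program and Block.create_var().
from collections import namedtuple

# FakeVar mimics the fields the diff copies: name, dtype, shape.
FakeVar = namedtuple("FakeVar", ["name", "dtype", "shape"])

def copy_pserver_vars_to_startup(pserver_vars, startup_vars):
    """Create one startup-program variable per pserver-program variable.

    pserver_vars: dict name -> FakeVar (plays the role of
                  pserver_program.global_block().vars)
    startup_vars: dict name -> FakeVar (plays the role of the new
                  startup program's global block vars)
    Returns created_var_map, mirroring the diff's bookkeeping dict.
    """
    created_var_map = {}
    for name, var in pserver_vars.items():
        tmpvar = FakeVar(var.name, var.dtype, var.shape)
        startup_vars[name] = tmpvar
        created_var_map[name] = tmpvar
    return created_var_map

if __name__ == "__main__":
    pserver_vars = {
        "fc_0.w_0.block0": FakeVar("fc_0.w_0.block0", "float32", [512, 100]),
    }
    startup_vars = {}
    print(copy_pserver_vars_to_startup(pserver_vars, startup_vars))
```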
@@ -469,30 +470,36 @@ class DistributeTranspiler:
                 dtype=var.dtype,
                 shape=var.shape)
             created_var_map[var.name] = tmpvar
-        optimize_op_input_var_names = [
-            v.name for v in pserver_program.global_block().vars.values()
-        ]
 
         # 2. rename op outputs
         for op in orig_s_prog.global_block().ops:
             new_outputs = dict()
+            # do not append startup op if var is not on this pserver
+            op_on_pserver = False
             for key, var in op.outputs.iteritems():
                 newname, _ = _get_splited_name_and_shape(var.name)
                 if newname:
+                    op_on_pserver = True
                     new_outputs[key] = created_var_map[newname]
-                else:
-                    new_outputs[key] = var
-            # do not append startup op if var is not on this pserver
-            op_on_pserver = False
-            for _, var in op.outputs.iteritems():
-                if var.name in optimize_op_input_var_names:
-                    op_on_pserver = True
-                    break
+                elif var.name in pserver_vars:
+                    op_on_pserver = True
+                    new_outputs[key] = pserver_vars[var.name]
+            # newname, _ = _get_splited_name_and_shape(var.name)
+            # if newname:
+            #     print("updating output", newname, created_var_map[newname])
+            #     new_outputs[key] = created_var_map[newname]
+            # else:
+            #     print("no update output", key, var)
+            #     new_outputs[key] = var
+            # if var.name in created_var_map or \
+            #     newname:
+            #     op_on_pserver = True
 
             if op_on_pserver:
+                # gaussian_random use attr to determine tensor shape
                 if op.type in ["gaussian_random", "fill_constant"]:
                     op.attrs["shape"] = new_outputs["Out"].shape
+                    print("updated shape", op.attrs["shape"])
                 s_prog.global_block().append_op(
                     type=op.type,
                     inputs=op.inputs,
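Step 2 of the change replays each op of the original startup program and appends it to this pserver's startup program only when at least one output either resolves to a split slice kept here (via `_get_splited_name_and_shape`) or already exists in `pserver_vars`; for `gaussian_random` and `fill_constant`, whose tensor shape comes from an attribute, the `shape` attr is rewritten to the renamed output's shape. The sketch below restates that filtering logic with dict stand-ins; `rewrite_startup_op`, `splited_name_of`, and the dict arguments are hypothetical, while the real code works on `orig_s_prog.global_block().ops` and calls `append_op`.

```python
# Framework-free sketch of step 2's per-op filtering and output renaming.
def rewrite_startup_op(op_type, op_outputs, op_attrs,
                       splited_name_of, created_var_map, pserver_vars):
    """Decide whether a startup op belongs on this pserver and rename outputs.

    op_outputs: dict slot name -> variable name, e.g. {"Out": "fc_0.w_0"}
    splited_name_of: callable returning this pserver's slice name (or "");
                     plays the role of _get_splited_name_and_shape.
    Returns (op_on_pserver, new_outputs, op_attrs).
    """
    new_outputs = {}
    op_on_pserver = False
    for key, var_name in op_outputs.items():
        newname = splited_name_of(var_name)
        if newname:                      # output is a slice kept on this pserver
            op_on_pserver = True
            new_outputs[key] = created_var_map[newname]
        elif var_name in pserver_vars:   # output lives on this pserver unsplit
            op_on_pserver = True
            new_outputs[key] = pserver_vars[var_name]

    if op_on_pserver and op_type in ["gaussian_random", "fill_constant"]:
        # these ops take their tensor shape from an attribute, so it must be
        # rewritten to the (possibly split) shape of the renamed output
        op_attrs = dict(op_attrs, shape=new_outputs["Out"].shape)
    return op_on_pserver, new_outputs, op_attrs

if __name__ == "__main__":
    from collections import namedtuple
    Var = namedtuple("Var", ["name", "shape"])
    pserver_vars = {"fc_0.b_0": Var("fc_0.b_0", [100])}
    created = {"fc_0.w_0.block0": Var("fc_0.w_0.block0", [256, 100])}
    keep, outs, attrs = rewrite_startup_op(
        "fill_constant", {"Out": "fc_0.w_0"}, {"shape": [512, 100]},
        lambda name: "fc_0.w_0.block0" if name == "fc_0.w_0" else "",
        created, pserver_vars)
    print(keep, outs, attrs)  # True, renamed output, shape -> [256, 100]
```

Keying the membership test on the pserver program's own variables, rather than a separately collected list of optimize-op input names, keeps the generated startup program consistent with whatever the transpiler actually placed on this pserver.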