提交 2827607f 编写于 作者: Y yi.wu

fix startup program shape

上级 6fa56b9d
...@@ -365,7 +365,6 @@ class DistributeTranspiler: ...@@ -365,7 +365,6 @@ class DistributeTranspiler:
else: else:
self._append_pserver_non_opt_ops(optimize_sub_program, opt_op) self._append_pserver_non_opt_ops(optimize_sub_program, opt_op)
print("####", optimize_sub_program)
pserver_program.global_block().append_op( pserver_program.global_block().append_op(
type="recv", type="recv",
inputs={"RX": self.param_grad_ep_mapping[endpoint]["grads"] inputs={"RX": self.param_grad_ep_mapping[endpoint]["grads"]
...@@ -407,7 +406,6 @@ class DistributeTranspiler: ...@@ -407,7 +406,6 @@ class DistributeTranspiler:
# 1. create vars # 1. create vars
created_var_map = dict() created_var_map = dict()
for var in params: for var in params:
print("%%%% append var", var.name, var.shape)
tmpvar = s_prog.global_block().create_var( tmpvar = s_prog.global_block().create_var(
name=var.name, name=var.name,
persistable=True, persistable=True,
...@@ -430,6 +428,8 @@ class DistributeTranspiler: ...@@ -430,6 +428,8 @@ class DistributeTranspiler:
if var.name in created_var_map: if var.name in created_var_map:
var_on_pserver = True var_on_pserver = True
if var_on_pserver: if var_on_pserver:
# gaussian_random use attr to determine tensor shape
op.attrs["shape"] = new_outputs["Out"].shape
s_prog.global_block().append_op( s_prog.global_block().append_op(
type=op.type, type=op.type,
inputs=op.inputs, inputs=op.inputs,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册