Commit 5325313e authored by T typhoonzero

debugging shape match

Parent 2827607f
......@@ -257,7 +257,45 @@ class DistributeTranspiler:
pass
return orig_shape
def _append_pserver_ops(self, program, opt_op, endpoint):
def _is_op_on_pserver(self, endpoint, all_ops, idx):
"""
Recursively check whether the op needs to run on the current parameter server.
Assume that the ops are in execution order.
"""
param_names = [
p.name for p in self.param_grad_ep_mapping[endpoint]["params"]
]
op = all_ops[idx]
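# An op that takes a "Param" input belongs on this pserver if that parameter,
# or one of its split ".block" pieces, is assigned to this endpoint.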
if op.inputs.has_key("Param"):
if op.inputs["Param"].name in param_names:
return True
else:
for n in param_names:
if n.startswith(op.inputs["Param"].name+".block") and \
n != op.inputs["Param"].name:
return True
return False
else:
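# No "Param" input: scan earlier ops backwards and recurse. If this op
# consumes a previous op's output, or produces an input of a previous op,
# and that previous op runs on this pserver, this op runs here as well.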
j = idx - 1
while j >= 0:
prev_op = all_ops[j]
prev_output_names = [o.name for o in prev_op.outputs.values()]
prev_input_names = [o.name for o in prev_op.inputs.values()]
found1 = False
found2 = False
for _, v in op.inputs.iteritems():
if v.name in prev_output_names:
found1 = self._is_op_on_pserver(endpoint, all_ops, j)
# Later ops may produce outputs that the previous op consumes in the next batch.
for _, v in op.outputs.iteritems():
if v.name in prev_input_names:
found2 = self._is_op_on_pserver(endpoint, all_ops, j)
if found1 or found2:
return True
j -= 1
return False
def _append_pserver_ops(self, program, pserver_program, opt_op, endpoint):
new_inputs = dict()
# update param/grad shape first, then other inputs like
# moment can use the updated shape
......@@ -321,6 +359,14 @@ class DistributeTranspiler:
dtype=var.dtype,
shape=new_shape)
new_inputs[key] = tmpvar
# Create the var in the pserver program's global block.
# TODO(typhoonzero): put blocks in one program to avoid creating two
# variables.
pserver_program.global_block().create_var(
name=var.name,
persistable=var.persistable,
dtype=var.dtype,
shape=new_shape)
# change outputs ParamOut variable
opt_op.outputs["ParamOut"] = new_inputs["Param"]
......@@ -330,13 +376,18 @@ class DistributeTranspiler:
outputs=opt_op.outputs,
attrs=opt_op.attrs)
def _append_pserver_non_opt_ops(self, program, opt_op):
def _append_pserver_non_opt_ops(self, program, pserver_program, opt_op):
for _, var in opt_op.inputs.iteritems():
program.global_block().create_var(
name=var.name,
persistable=var.persistable,
dtype=var.dtype,
shape=var.shape)
pserver_program.global_block().create_var(
name=var.name,
persistable=var.persistable,
dtype=var.dtype,
shape=var.shape)
program.global_block().append_op(
type=opt_op.type,
inputs=opt_op.inputs,
......@@ -358,13 +409,18 @@ class DistributeTranspiler:
self._clone_var(pserver_program.global_block(), v)
# step6
optimize_sub_program = Program()
for opt_op in optimize_ops:
for idx, opt_op in enumerate(optimize_ops):
is_op_on_pserver = self._is_op_on_pserver(endpoint, optimize_ops,
idx)
if not is_op_on_pserver:
continue
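# Optimizer ops take a "Grad" input and need their input shapes remapped
# to the split parameter blocks; all other ops are appended unchanged.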
if opt_op.inputs.has_key("Grad"):
# append optimize_op
self._append_pserver_ops(optimize_sub_program, opt_op, endpoint)
self._append_pserver_ops(optimize_sub_program, pserver_program,
opt_op, endpoint)
else:
self._append_pserver_non_opt_ops(optimize_sub_program, opt_op)
self._append_pserver_non_opt_ops(optimize_sub_program,
pserver_program, opt_op)
print("****subprogram", optimize_sub_program)
pserver_program.global_block().append_op(
type="recv",
inputs={"RX": self.param_grad_ep_mapping[endpoint]["grads"]
......@@ -386,7 +442,7 @@ class DistributeTranspiler:
pserver_program.sync_with_cpp()
return pserver_program
def get_startup_program(self, endpoint):
def get_startup_program(self, endpoint, pserver_program):
"""
Get startup program for current parameter server.
Modify operator input variables if there are variables that
......@@ -405,13 +461,17 @@ class DistributeTranspiler:
# 1. create vars
created_var_map = dict()
for var in params:
for _, var in pserver_program.global_block().vars.iteritems():
print("create var for startup", var.name, var.shape)
tmpvar = s_prog.global_block().create_var(
name=var.name,
persistable=True,
persistable=var.persistable,
dtype=var.dtype,
shape=var.shape)
created_var_map[var.name] = tmpvar
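# Names of every var in the pserver program; startup ops whose outputs
# are not in this set are skipped below.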
optimize_op_input_var_names = [
v.name for v in pserver_program.global_block().vars.values()
]
# 2. rename op outputs
for op in orig_s_prog.global_block().ops:
......@@ -423,13 +483,16 @@ class DistributeTranspiler:
else:
new_outputs[key] = var
# do not append startup op if var is not on this pserver
var_on_pserver = False
for _, var in new_outputs.iteritems():
if var.name in created_var_map:
var_on_pserver = True
if var_on_pserver:
op_on_pserver = False
for _, var in op.outputs.iteritems():
if var.name in optimize_op_input_var_names:
op_on_pserver = True
break
if op_on_pserver:
# gaussian_random uses an attr to determine the tensor shape
op.attrs["shape"] = new_outputs["Out"].shape
if op.type in ["gaussian_random", "fill_constant"]:
op.attrs["shape"] = new_outputs["Out"].shape
s_prog.global_block().append_op(
type=op.type,
inputs=op.inputs,
......