Commit 5325313e authored by typhoonzero

debugging shape match

Parent 2827607f
@@ -257,7 +257,45 @@ class DistributeTranspiler:
             pass
         return orig_shape
 
-    def _append_pserver_ops(self, program, opt_op, endpoint):
+    def _is_op_on_pserver(self, endpoint, all_ops, idx):
+        """
+        Recursively check if the op need to run on current server.
+        Assume that ops are in the execution order.
+        """
+        param_names = [
+            p.name for p in self.param_grad_ep_mapping[endpoint]["params"]
+        ]
+        op = all_ops[idx]
+        if op.inputs.has_key("Param"):
+            if op.inputs["Param"].name in param_names:
+                return True
+            else:
+                for n in param_names:
+                    if n.startswith(op.inputs["Param"].name+".block") and \
+                       n != op.inputs["Param"].name:
+                        return True
+                return False
+        else:
+            j = idx - 1
+            while j >= 0:
+                prev_op = all_ops[j]
+                prev_output_names = [o.name for o in prev_op.outputs.values()]
+                prev_input_names = [o.name for o in prev_op.inputs.values()]
+                found1 = False
+                found2 = False
+                for _, v in op.inputs.iteritems():
+                    if v.name in prev_output_names:
+                        found1 = self._is_op_on_pserver(endpoint, all_ops, j)
+                # later ops may produce output for prev op's next batch use.
+                for _, v in op.outputs.iteritems():
+                    if v.name in prev_input_names:
+                        found2 = self._is_op_on_pserver(endpoint, all_ops, j)
+                if found1 or found2:
+                    return True
+                j -= 1
+            return False
+
+    def _append_pserver_ops(self, program, pserver_program, opt_op, endpoint):
         new_inputs = dict()
         # update param/grad shape first, then other inputs like
         # moment can use the updated shape
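The `_is_op_on_pserver` helper added above decides whether an op from the optimize graph belongs on the current parameter server: ops carrying a `Param` input are matched against this endpoint's parameter (or split parameter-block) names, and any other op is kept only if it is connected, through shared variables, to an earlier op that is. Below is a simplified, standalone sketch of that backward walk, using toy `Op` records as a hypothetical stand-in for the framework's operator class:

```python
# Simplified sketch (not the framework code) of the backward dependency walk
# performed by _is_op_on_pserver. "Op" is a hypothetical stand-in holding the
# parameter it updates (if any) plus the variable names it reads and writes.
from collections import namedtuple

Op = namedtuple("Op", ["param", "input_names", "output_names"])


def is_op_on_pserver(param_names, all_ops, idx):
    op = all_ops[idx]
    if op.param is not None:
        # Optimize ops name their parameter directly; a split parameter shows
        # up on the pserver as "<param>.block<N>".
        return any(n == op.param or n.startswith(op.param + ".block")
                   for n in param_names)
    # No Param input: keep the op only if it is connected, through any
    # variable, to an earlier op that itself runs on this pserver.
    for j in range(idx - 1, -1, -1):
        prev = all_ops[j]
        reads_prev = any(name in prev.output_names for name in op.input_names)
        feeds_prev = any(name in prev.input_names for name in op.output_names)
        if (reads_prev or feeds_prev) and is_op_on_pserver(param_names, all_ops, j):
            return True
    return False


ops = [
    # sgd-like op updating a parameter block that lives on this pserver
    Op(param="fc_0.w_0.block0",
       input_names=["fc_0.w_0.block0", "fc_0.w_0.block0@GRAD"],
       output_names=["fc_0.w_0.block0"]),
    # a follow-up op that only consumes the updated block
    Op(param=None,
       input_names=["fc_0.w_0.block0"],
       output_names=["fc_0.w_0_norm"]),
]
print(is_op_on_pserver({"fc_0.w_0.block0"}, ops, 1))  # True
```

As in the real helper, the walk only looks at earlier ops, which relies on the docstring's assumption that `all_ops` is in execution order; the extra outputs-to-inputs check mirrors the "next batch use" case noted in the commit.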
@@ -321,6 +359,14 @@ class DistributeTranspiler:
                 dtype=var.dtype,
                 shape=new_shape)
             new_inputs[key] = tmpvar
+            # create var in pserver program global block.
+            # TODO(typhoonzero): put blocks in one program to avoid create two
+            # variables.
+            pserver_program.global_block().create_var(
+                name=var.name,
+                persistable=var.persistable,
+                dtype=var.dtype,
+                shape=new_shape)
 
         # change outputs ParamOut variable
         opt_op.outputs["ParamOut"] = new_inputs["Param"]
@@ -330,13 +376,18 @@ class DistributeTranspiler:
             outputs=opt_op.outputs,
             attrs=opt_op.attrs)
 
-    def _append_pserver_non_opt_ops(self, program, opt_op):
+    def _append_pserver_non_opt_ops(self, program, pserver_program, opt_op):
         for _, var in opt_op.inputs.iteritems():
             program.global_block().create_var(
                 name=var.name,
                 persistable=var.persistable,
                 dtype=var.dtype,
                 shape=var.shape)
+            pserver_program.global_block().create_var(
+                name=var.name,
+                persistable=var.persistable,
+                dtype=var.dtype,
+                shape=var.shape)
         program.global_block().append_op(
             type=opt_op.type,
             inputs=opt_op.inputs,
@@ -358,13 +409,18 @@ class DistributeTranspiler:
             self._clone_var(pserver_program.global_block(), v)
         # step6
         optimize_sub_program = Program()
-        for opt_op in optimize_ops:
+        for idx, opt_op in enumerate(optimize_ops):
+            is_op_on_pserver = self._is_op_on_pserver(endpoint, optimize_ops,
+                                                      idx)
+            if not is_op_on_pserver:
+                continue
             if opt_op.inputs.has_key("Grad"):
-                # append optimize_op
-                self._append_pserver_ops(optimize_sub_program, opt_op, endpoint)
+                self._append_pserver_ops(optimize_sub_program, pserver_program,
+                                         opt_op, endpoint)
             else:
-                self._append_pserver_non_opt_ops(optimize_sub_program, opt_op)
+                self._append_pserver_non_opt_ops(optimize_sub_program,
+                                                 pserver_program, opt_op)
+        print("****subprogram", optimize_sub_program)
         pserver_program.global_block().append_op(
             type="recv",
             inputs={"RX": self.param_grad_ep_mapping[endpoint]["grads"]
@@ -386,7 +442,7 @@ class DistributeTranspiler:
         pserver_program.sync_with_cpp()
         return pserver_program
 
-    def get_startup_program(self, endpoint):
+    def get_startup_program(self, endpoint, pserver_program):
         """
         Get startup program for current parameter server.
         Modify operator input variables if there are variables that
@@ -405,13 +461,17 @@ class DistributeTranspiler:
         # 1. create vars
         created_var_map = dict()
-        for var in params:
+        for _, var in pserver_program.global_block().vars.iteritems():
+            print("create var for startup", var.name, var.shape)
             tmpvar = s_prog.global_block().create_var(
                 name=var.name,
-                persistable=True,
+                persistable=var.persistable,
                 dtype=var.dtype,
                 shape=var.shape)
             created_var_map[var.name] = tmpvar
+        optimize_op_input_var_names = [
+            v.name for v in pserver_program.global_block().vars.values()
+        ]
 
         # 2. rename op outputs
         for op in orig_s_prog.global_block().ops:
@@ -423,13 +483,16 @@ class DistributeTranspiler:
                 else:
                     new_outputs[key] = var
             # do not append startup op if var is not on this pserver
-            var_on_pserver = False
-            for _, var in new_outputs.iteritems():
-                if var.name in created_var_map:
-                    var_on_pserver = True
-            if var_on_pserver:
+            op_on_pserver = False
+            for _, var in op.outputs.iteritems():
+                if var.name in optimize_op_input_var_names:
+                    op_on_pserver = True
+                    break
+            if op_on_pserver:
                 # gaussian_random use attr to determine tensor shape
-                op.attrs["shape"] = new_outputs["Out"].shape
+                if op.type in ["gaussian_random", "fill_constant"]:
+                    op.attrs["shape"] = new_outputs["Out"].shape
                 s_prog.global_block().append_op(
                     type=op.type,
                     inputs=op.inputs,
......
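For context, a rough sketch of how a parameter-server process might drive these entry points after this change: `get_startup_program` now takes the already-built `pserver_program`, so startup variables are created with the per-block shapes found there. Everything outside this diff (the toy model, the `transpile` and `get_pserver_program` signatures, the import path) is assumed from the fluid API of roughly this period and may not match this exact revision:

```python
import paddle.v2.fluid as fluid

# Tiny model so that minimize() yields optimize_ops / params_grads.
x = fluid.layers.data(name="x", shape=[13], dtype="float32")
y = fluid.layers.data(name="y", shape=[1], dtype="float32")
y_predict = fluid.layers.fc(input=x, size=1)
cost = fluid.layers.square_error_cost(input=y_predict, label=y)
avg_cost = fluid.layers.mean(x=cost)
optimize_ops, params_grads = fluid.optimizer.SGDOptimizer(
    learning_rate=0.01).minimize(avg_cost)

pserver_endpoint = "127.0.0.1:6174"
t = fluid.DistributeTranspiler()
t.transpile(optimize_ops, params_grads, pservers=pserver_endpoint, trainers=1)

pserver_prog = t.get_pserver_program(pserver_endpoint, optimize_ops)
# Changed in this commit: the startup program is derived from the pserver
# program's variables, so it is passed in explicitly.
startup_prog = t.get_startup_program(pserver_endpoint, pserver_prog)

exe = fluid.Executor(fluid.CPUPlace())
exe.run(startup_prog)  # initialize the (possibly split) parameter blocks
exe.run(pserver_prog)  # blocks here, serving recv/optimize/send for trainers
```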