提交 c7444501 编写于 作者: T typhoonzero

refine distribute transpiler

上级 b41205d9
......@@ -300,6 +300,9 @@ class DistributeTranspiler:
pass
return orig_shape
def _op_input_var(self, op, varname):
pass
def _is_op_on_pserver(self, endpoint, all_ops, idx):
"""
Recursively check if the op need to run on current server.
......@@ -309,29 +312,35 @@ class DistributeTranspiler:
p.name for p in self.param_grad_ep_mapping[endpoint]["params"]
]
op = all_ops[idx]
if op.inputs.has_key("Param"):
if op.inputs["Param"].name in param_names:
input_names = set(op.input_names)
# TODO(typhoonzero): using Param and Grad input name to identify
# that the operator is an optimization operator, need a better way.
if "Param" in input_names:
if op.input("Param")[0] in param_names:
return True
else:
for n in param_names:
if same_or_split_var(n, op.inputs[
"Param"].name) and n != op.inputs["Param"].name:
if same_or_split_var(n, op.input("Param")[0]) \
and n != op.input("Param")[0]:
return True
return False
else:
j = idx - 1
while j >= 0:
prev_op = all_ops[j]
prev_output_names = [o.name for o in prev_op.outputs.values()]
prev_input_names = [o.name for o in prev_op.inputs.values()]
# prev_output_names = [o.name for o in prev_op.outputs.values()]
# prev_input_names = [o.name for o in prev_op.inputs.values()]
# NOTE(typhoonzero): consider list input/output
prev_output_names = prev_op.desc.output_arg_names()
prev_input_names = prev_op.desc.input_arg_names()
found1 = False
found2 = False
for _, v in op.inputs.iteritems():
if v.name in prev_output_names:
for varname in op.desc.input_arg_names():
if varname in prev_output_names:
found1 = self._is_op_on_pserver(endpoint, all_ops, j)
# later ops may produce output for prev op's next batch use.
for _, v in op.outputs.iteritems():
if v.name in prev_input_names:
for varname in op.desc.output_arg_names():
if varname in prev_input_names:
found2 = self._is_op_on_pserver(endpoint, all_ops, j)
if found1 or found2:
return True
......@@ -342,11 +351,11 @@ class DistributeTranspiler:
new_inputs = dict()
# update param/grad shape first, then other inputs like
# moment can use the updated shape
for key, var in opt_op.inputs.iteritems():
for key in opt_op.input_names:
if key == "Grad":
grad_block = None
for g in self.param_grad_ep_mapping[endpoint]["grads"]:
if same_or_split_var(g.name, var.name):
if same_or_split_var(g.name, opt_op.input(key)[0]):
grad_block = g
break
if not grad_block:
......@@ -376,7 +385,7 @@ class DistributeTranspiler:
# param is already created on global program
param_block = None
for p in self.param_grad_ep_mapping[endpoint]["params"]:
if same_or_split_var(p.name, var.name):
if same_or_split_var(p.name, opt_op.input(key)):
param_block = p
break
if not param_block:
......@@ -389,11 +398,12 @@ class DistributeTranspiler:
new_inputs[key] = tmpvar
for key, var in opt_op.inputs.iteritems():
for key in opt_op.input_names:
if key in ["Param", "Grad"]:
continue
# update accumulator variable shape
param_shape = new_inputs["Param"].shape
var = program.global_block().vars[opt_op.input(key)]
new_shape = self._get_optimizer_input_shape(opt_op.type, key,
var.shape, param_shape)
tmpvar = program.global_block().create_var(
......@@ -412,30 +422,46 @@ class DistributeTranspiler:
shape=new_shape)
# change output's ParamOut variable
opt_op.outputs["ParamOut"] = new_inputs["Param"]
outputs = self._get_output_map_from_op(program.global_block(), opt_op)
outputs["ParamOut"] = new_inputs["Param"]
program.global_block().append_op(
type=opt_op.type,
inputs=new_inputs,
outputs=opt_op.outputs,
outputs=outputs,
attrs=opt_op.attrs)
def _append_pserver_non_opt_ops(self, program, pserver_program, opt_op):
# Append the ops for parameters that do not need to be optimized/updated
for _, var in opt_op.inputs.iteritems():
inputs = self._get_input_map_from_op(self.program.global_block().vars,
opt_op)
for var in inputs.itervalues():
if type(var) == list:
varlist = var
else:
varlist = [var]
for var in varlist:
program.global_block().create_var(
name=var.name,
persistable=var.persistable,
dtype=var.dtype,
shape=var.shape)
try:
pserver_program.global_block().create_var(
name=var.name,
persistable=var.persistable,
dtype=var.dtype,
shape=var.shape)
except ValueError:
# create var if not created yet.
pass
outputs = self._get_output_map_from_op(self.program.global_block().vars,
opt_op)
program.global_block().append_op(
type=opt_op.type,
inputs=opt_op.inputs,
outputs=opt_op.outputs,
inputs=inputs,
outputs=outputs,
attrs=opt_op.attrs)
def get_pserver_program(self, endpoint):
......@@ -472,7 +498,7 @@ class DistributeTranspiler:
self.optimize_ops, idx)
if not is_op_on_pserver:
continue
if opt_op.inputs.has_key("Grad"):
if "Grad" in opt_op.desc.input_arg_names():
self._append_pserver_ops(optimize_sub_program, pserver_program,
opt_op, endpoint)
else:
......@@ -499,6 +525,30 @@ class DistributeTranspiler:
pserver_program.sync_with_cpp()
return pserver_program
def _get_input_map_from_op(self, varmap, op):
iomap = dict()
for key in op.input_names:
vars = []
for varname in op.input(key):
vars.append(varmap[varname])
if len(vars) == 1:
iomap[key] = vars[0]
else:
iomap[key] = vars
return iomap
def _get_output_map_from_op(self, varmap, op):
iomap = dict()
for key in op.output_names:
vars = []
for varname in op.output(key):
vars.append(varmap[varname])
if len(vars) == 1:
iomap[key] = vars[0]
else:
iomap[key] = vars
return iomap
def get_startup_program(self, endpoint, pserver_program):
"""
Get startup program for current parameter server.
......@@ -529,17 +579,21 @@ class DistributeTranspiler:
# 2. rename op outputs
for op in orig_s_prog.global_block().ops:
new_inputs = dict()
new_outputs = dict()
# do not append startup op if var is not on this pserver
op_on_pserver = False
for key, var in op.outputs.iteritems():
newname, _ = _get_splited_name_and_shape(var.name)
for key in op.output_names:
newname, _ = _get_splited_name_and_shape(op.output(key)[0])
if newname:
op_on_pserver = True
new_outputs[key] = created_var_map[newname]
elif var.name in pserver_vars:
elif op.output(key)[0] in pserver_vars:
op_on_pserver = True
new_outputs[key] = pserver_vars[var.name]
new_outputs[key] = pserver_vars[op.output(key)[0]]
# most startup program ops have no inputs
new_inputs = self._get_input_map_from_op(pserver_vars, op)
if op_on_pserver:
if op.type in [
......@@ -548,7 +602,7 @@ class DistributeTranspiler:
op.attrs["shape"] = new_outputs["Out"].shape
s_prog.global_block().append_op(
type=op.type,
inputs=op.inputs,
inputs=new_inputs,
outputs=new_outputs,
attrs=op.attrs)
return s_prog
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册