提交 c7444501 编写于 作者: T typhoonzero

refine distribute transpiler

上级 b41205d9
...@@ -300,6 +300,9 @@ class DistributeTranspiler: ...@@ -300,6 +300,9 @@ class DistributeTranspiler:
pass pass
return orig_shape return orig_shape
def _op_input_var(self, op, varname):
pass
def _is_op_on_pserver(self, endpoint, all_ops, idx): def _is_op_on_pserver(self, endpoint, all_ops, idx):
""" """
Recursively check if the op need to run on current server. Recursively check if the op need to run on current server.
...@@ -309,29 +312,35 @@ class DistributeTranspiler: ...@@ -309,29 +312,35 @@ class DistributeTranspiler:
p.name for p in self.param_grad_ep_mapping[endpoint]["params"] p.name for p in self.param_grad_ep_mapping[endpoint]["params"]
] ]
op = all_ops[idx] op = all_ops[idx]
if op.inputs.has_key("Param"): input_names = set(op.input_names)
if op.inputs["Param"].name in param_names: # TODO(typhoonzero): using Param and Grad input name to identify
# that the operator is an optimization operator, need a better way.
if "Param" in input_names:
if op.input("Param")[0] in param_names:
return True return True
else: else:
for n in param_names: for n in param_names:
if same_or_split_var(n, op.inputs[ if same_or_split_var(n, op.input("Param")[0]) \
"Param"].name) and n != op.inputs["Param"].name: and n != op.input("Param")[0]:
return True return True
return False return False
else: else:
j = idx - 1 j = idx - 1
while j >= 0: while j >= 0:
prev_op = all_ops[j] prev_op = all_ops[j]
prev_output_names = [o.name for o in prev_op.outputs.values()] # prev_output_names = [o.name for o in prev_op.outputs.values()]
prev_input_names = [o.name for o in prev_op.inputs.values()] # prev_input_names = [o.name for o in prev_op.inputs.values()]
# NOTE(typhoonzero): consider list input/output
prev_output_names = prev_op.desc.output_arg_names()
prev_input_names = prev_op.desc.input_arg_names()
found1 = False found1 = False
found2 = False found2 = False
for _, v in op.inputs.iteritems(): for varname in op.desc.input_arg_names():
if v.name in prev_output_names: if varname in prev_output_names:
found1 = self._is_op_on_pserver(endpoint, all_ops, j) found1 = self._is_op_on_pserver(endpoint, all_ops, j)
# later ops may produce output for prev op's next batch use. # later ops may produce output for prev op's next batch use.
for _, v in op.outputs.iteritems(): for varname in op.desc.output_arg_names():
if v.name in prev_input_names: if varname in prev_input_names:
found2 = self._is_op_on_pserver(endpoint, all_ops, j) found2 = self._is_op_on_pserver(endpoint, all_ops, j)
if found1 or found2: if found1 or found2:
return True return True
...@@ -342,11 +351,11 @@ class DistributeTranspiler: ...@@ -342,11 +351,11 @@ class DistributeTranspiler:
new_inputs = dict() new_inputs = dict()
# update param/grad shape first, then other inputs like # update param/grad shape first, then other inputs like
# moment can use the updated shape # moment can use the updated shape
for key, var in opt_op.inputs.iteritems(): for key in opt_op.input_names:
if key == "Grad": if key == "Grad":
grad_block = None grad_block = None
for g in self.param_grad_ep_mapping[endpoint]["grads"]: for g in self.param_grad_ep_mapping[endpoint]["grads"]:
if same_or_split_var(g.name, var.name): if same_or_split_var(g.name, opt_op.input(key)[0]):
grad_block = g grad_block = g
break break
if not grad_block: if not grad_block:
...@@ -376,7 +385,7 @@ class DistributeTranspiler: ...@@ -376,7 +385,7 @@ class DistributeTranspiler:
# param is already created on global program # param is already created on global program
param_block = None param_block = None
for p in self.param_grad_ep_mapping[endpoint]["params"]: for p in self.param_grad_ep_mapping[endpoint]["params"]:
if same_or_split_var(p.name, var.name): if same_or_split_var(p.name, opt_op.input(key)):
param_block = p param_block = p
break break
if not param_block: if not param_block:
...@@ -389,11 +398,12 @@ class DistributeTranspiler: ...@@ -389,11 +398,12 @@ class DistributeTranspiler:
new_inputs[key] = tmpvar new_inputs[key] = tmpvar
for key, var in opt_op.inputs.iteritems(): for key in opt_op.input_names:
if key in ["Param", "Grad"]: if key in ["Param", "Grad"]:
continue continue
# update accumulator variable shape # update accumulator variable shape
param_shape = new_inputs["Param"].shape param_shape = new_inputs["Param"].shape
var = program.global_block().vars[opt_op.input(key)]
new_shape = self._get_optimizer_input_shape(opt_op.type, key, new_shape = self._get_optimizer_input_shape(opt_op.type, key,
var.shape, param_shape) var.shape, param_shape)
tmpvar = program.global_block().create_var( tmpvar = program.global_block().create_var(
...@@ -412,30 +422,46 @@ class DistributeTranspiler: ...@@ -412,30 +422,46 @@ class DistributeTranspiler:
shape=new_shape) shape=new_shape)
# change output's ParamOut variable # change output's ParamOut variable
opt_op.outputs["ParamOut"] = new_inputs["Param"] outputs = self._get_output_map_from_op(program.global_block(), opt_op)
outputs["ParamOut"] = new_inputs["Param"]
program.global_block().append_op( program.global_block().append_op(
type=opt_op.type, type=opt_op.type,
inputs=new_inputs, inputs=new_inputs,
outputs=opt_op.outputs, outputs=outputs,
attrs=opt_op.attrs) attrs=opt_op.attrs)
def _append_pserver_non_opt_ops(self, program, pserver_program, opt_op): def _append_pserver_non_opt_ops(self, program, pserver_program, opt_op):
# Append the ops for parameters that do not need to be optimized/updated # Append the ops for parameters that do not need to be optimized/updated
for _, var in opt_op.inputs.iteritems(): inputs = self._get_input_map_from_op(self.program.global_block().vars,
program.global_block().create_var( opt_op)
name=var.name, for var in inputs.itervalues():
persistable=var.persistable, if type(var) == list:
dtype=var.dtype, varlist = var
shape=var.shape) else:
pserver_program.global_block().create_var( varlist = [var]
name=var.name, for var in varlist:
persistable=var.persistable, program.global_block().create_var(
dtype=var.dtype, name=var.name,
shape=var.shape) persistable=var.persistable,
dtype=var.dtype,
shape=var.shape)
try:
pserver_program.global_block().create_var(
name=var.name,
persistable=var.persistable,
dtype=var.dtype,
shape=var.shape)
except ValueError:
# create var if not created yet.
pass
outputs = self._get_output_map_from_op(self.program.global_block().vars,
opt_op)
program.global_block().append_op( program.global_block().append_op(
type=opt_op.type, type=opt_op.type,
inputs=opt_op.inputs, inputs=inputs,
outputs=opt_op.outputs, outputs=outputs,
attrs=opt_op.attrs) attrs=opt_op.attrs)
def get_pserver_program(self, endpoint): def get_pserver_program(self, endpoint):
...@@ -472,7 +498,7 @@ class DistributeTranspiler: ...@@ -472,7 +498,7 @@ class DistributeTranspiler:
self.optimize_ops, idx) self.optimize_ops, idx)
if not is_op_on_pserver: if not is_op_on_pserver:
continue continue
if opt_op.inputs.has_key("Grad"): if "Grad" in opt_op.desc.input_arg_names():
self._append_pserver_ops(optimize_sub_program, pserver_program, self._append_pserver_ops(optimize_sub_program, pserver_program,
opt_op, endpoint) opt_op, endpoint)
else: else:
...@@ -499,6 +525,30 @@ class DistributeTranspiler: ...@@ -499,6 +525,30 @@ class DistributeTranspiler:
pserver_program.sync_with_cpp() pserver_program.sync_with_cpp()
return pserver_program return pserver_program
def _get_input_map_from_op(self, varmap, op):
iomap = dict()
for key in op.input_names:
vars = []
for varname in op.input(key):
vars.append(varmap[varname])
if len(vars) == 1:
iomap[key] = vars[0]
else:
iomap[key] = vars
return iomap
def _get_output_map_from_op(self, varmap, op):
iomap = dict()
for key in op.output_names:
vars = []
for varname in op.output(key):
vars.append(varmap[varname])
if len(vars) == 1:
iomap[key] = vars[0]
else:
iomap[key] = vars
return iomap
def get_startup_program(self, endpoint, pserver_program): def get_startup_program(self, endpoint, pserver_program):
""" """
Get startup program for current parameter server. Get startup program for current parameter server.
...@@ -529,17 +579,21 @@ class DistributeTranspiler: ...@@ -529,17 +579,21 @@ class DistributeTranspiler:
# 2. rename op outputs # 2. rename op outputs
for op in orig_s_prog.global_block().ops: for op in orig_s_prog.global_block().ops:
new_inputs = dict()
new_outputs = dict() new_outputs = dict()
# do not append startup op if var is not on this pserver # do not append startup op if var is not on this pserver
op_on_pserver = False op_on_pserver = False
for key, var in op.outputs.iteritems(): for key in op.output_names:
newname, _ = _get_splited_name_and_shape(var.name) newname, _ = _get_splited_name_and_shape(op.output(key)[0])
if newname: if newname:
op_on_pserver = True op_on_pserver = True
new_outputs[key] = created_var_map[newname] new_outputs[key] = created_var_map[newname]
elif var.name in pserver_vars: elif op.output(key)[0] in pserver_vars:
op_on_pserver = True op_on_pserver = True
new_outputs[key] = pserver_vars[var.name] new_outputs[key] = pserver_vars[op.output(key)[0]]
# most startup program ops have no inputs
new_inputs = self._get_input_map_from_op(pserver_vars, op)
if op_on_pserver: if op_on_pserver:
if op.type in [ if op.type in [
...@@ -548,7 +602,7 @@ class DistributeTranspiler: ...@@ -548,7 +602,7 @@ class DistributeTranspiler:
op.attrs["shape"] = new_outputs["Out"].shape op.attrs["shape"] = new_outputs["Out"].shape
s_prog.global_block().append_op( s_prog.global_block().append_op(
type=op.type, type=op.type,
inputs=op.inputs, inputs=new_inputs,
outputs=new_outputs, outputs=new_outputs,
attrs=op.attrs) attrs=op.attrs)
return s_prog return s_prog
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册