Commit 2cfb2928 authored by typhoonzero

Fix develop dist transpiler bug

Parent caf9a09d
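As the hunks below show, the fix does two things: the parameter-server path in _append_pserver_ops reuses a single pserver_block handle instead of calling program.global_block() at every site, and optimizer ops are inspected through op.input_names / op.input() / op.desc.input_arg_names() instead of the old op.inputs and op.outputs dicts (the _create_lr_param_mapping helper is removed and _fetch_var_names is commented out accordingly). A minimal, runnable sketch of the new style of check; FakeOp and is_optimize_op are hypothetical illustrations, not the transpiler's code:

# FakeOp is a made-up stand-in for a framework Operator, exposing the two
# accessors the diff relies on: input_names and input(name) -> list of names.
class FakeOp(object):
    def __init__(self, inputs):
        # inputs: dict mapping input slot name -> list of argument (variable) names
        self._inputs = inputs
        self.input_names = list(inputs.keys())

    def input(self, name):
        # returns the list of argument names bound to the given input slot
        return self._inputs[name]


def is_optimize_op(op):
    # SGD/Momentum/Adam-style optimizer ops all take a Param and a LearningRate
    # input, so membership tests on input_names replace the old has_key() calls.
    return "Param" in op.input_names and "LearningRate" in op.input_names


sgd = FakeOp({"Param": ["fc_0.w_0"],
              "Grad": ["fc_0.w_0@GRAD"],
              "LearningRate": ["learning_rate_0"]})
print(is_optimize_op(sgd))    # True
print(sgd.input("Param")[0])  # fc_0.w_0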
@@ -191,7 +191,6 @@ class DistributeTranspiler:
for b in param_blocks:
varname, block_id, _ = b.split(":")
send_outputs.append(param_var_mapping[varname][int(block_id)])
# let send_op know which endpoint to send which var to, eplist has the same
# order as send_inputs.
eplist = split_method(send_inputs, pserver_endpoints)
@@ -230,21 +229,6 @@ class DistributeTranspiler:
outputs={"Out": [orig_param]},
attrs={"axis": 0})
- self.lr_param_mapping = self._create_lr_param_mapping()
- def _create_lr_param_mapping(self):
- lr_mapping = dict()
- for _, opt_op in enumerate(self.optimize_ops):
- if not opt_op.inputs or not opt_op.inputs.has_key("LearningRate") \
- or not opt_op.inputs.has_key("Param"):
- continue
- lr = opt_op.inputs["LearningRate"].name
- param = opt_op.inputs["Param"].name
- if not lr_mapping.has_key(lr):
- lr_mapping.update({lr: list()})
- lr_mapping[lr].append(param)
- return lr_mapping
def _create_vars_from_blocklist(self, program, block_list):
# Create respective variables using the block_list
block_map = dict()
@@ -369,18 +353,19 @@ class DistributeTranspiler:
pass
return orig_shape
- def _fetch_var_names(self, param_dict):
- res = []
- if not param_dict:
- return res
- for _, values in param_dict.iteritems():
- if not isinstance(values, list):
- values = [values]
- res += [v.name for v in values]
- return res
+ # def _fetch_var_names(self, param_dict):
+ # res = []
+ # if not param_dict:
+ # return res
+ # for _, values in param_dict.iteritems():
+ # if not isinstance(values, list):
+ # values = [values]
+ # res += [v.name for v in values]
+ # return res
def _append_pserver_ops(self, optimize_block, opt_op, endpoint):
program = optimize_block.program
+ pserver_block = program.global_block()
new_inputs = dict()
# update param/grad shape first, then other inputs like
# moment can use the updated shape
@@ -395,11 +380,11 @@ class DistributeTranspiler:
# do not append this op if current endpoint
# is not dealing with this grad block
return
- merged_var = program.global_block().vars[grad_block.name]
+ merged_var = pserver_block.vars[grad_block.name]
# append merging ops if trainers > 1
if self.trainers > 1:
vars2merge = self._create_var_for_trainers(
- program.global_block(), grad_block, self.trainers)
+ pserver_block, grad_block, self.trainers)
optimize_block.append_op(
type="sum",
inputs={"X": vars2merge},
@@ -419,29 +404,27 @@ class DistributeTranspiler:
break
if not param_block:
return
- tmpvar = program.global_block().create_var(
+ tmpvar = pserver_block.create_var(
name=param_block.name,
persistable=True,
dtype=param_block.dtype,
shape=param_block.shape)
new_inputs[key] = tmpvar
elif key == "LearningRate":
# leraning rate variable has already be created by non-optimize op,
# don't create it once again.
- new_inputs[key] = program.global_block().vars[opt_op.input(key)[
- 0]]
+ new_inputs[key] = pserver_block.vars[opt_op.input(key)[0]]
for key in opt_op.input_names:
new_shape = None
if key in ["Param", "Grad", "LearningRate"]:
continue
- var = program.global_block().vars[opt_op.input(key)[0]]
+ var = self.program.global_block().vars[opt_op.input(key)[0]]
# update accumulator variable shape
param_shape = new_inputs["Param"].shape
new_shape = self._get_optimizer_input_shape(opt_op.type, key,
var.shape, param_shape)
- tmpvar = program.global_block().create_var(
+ tmpvar = pserver_block.create_var(
name=var.name,
persistable=var.persistable,
dtype=var.dtype,
@@ -449,11 +432,14 @@ class DistributeTranspiler:
new_inputs[key] = tmpvar
# change output's ParamOut variable
- opt_op.outputs["ParamOut"] = new_inputs["Param"]
+ outputs = self._get_output_map_from_op(self.program.global_block().vars,
+ opt_op)
optimize_block.append_op(
type=opt_op.type,
inputs=new_inputs,
- outputs=opt_op.outputs,
+ outputs=outputs,
attrs=opt_op.attrs)
def _append_pserver_non_opt_ops(self, optimize_block, opt_op):
@@ -497,11 +483,16 @@ class DistributeTranspiler:
# If one op's input is another op's output or
# one op's output is another op's input, we say
# the two operator is connected.
- op1_input_names = self._fetch_var_names(op1.inputs)
- op1_output_names = self._fetch_var_names(op1.outputs)
+ # op1_input_names = self._fetch_var_names(op1.inputs)
+ # op1_output_names = self._fetch_var_names(op1.outputs)
+ op1_input_names = op1.desc.input_arg_names()
+ op1_output_names = op1.desc.output_arg_names()
- op2_input_names = self._fetch_var_names(op2.inputs)
- op2_output_names = self._fetch_var_names(op2.outputs)
+ # op2_input_names = self._fetch_var_names(op2.inputs)
+ # op2_output_names = self._fetch_var_names(op2.outputs)
+ op2_input_names = op2.desc.input_arg_names()
+ op2_output_names = op2.desc.output_arg_names()
if set(op1_output_names) & set(op2_input_names) or \
set(op1_input_names) & set(op2_output_names):
return True
@@ -521,8 +512,8 @@ class DistributeTranspiler:
def _is_opt_op(self, op):
# NOTE: It's a HACK implement.
# optimize op: SGDOptimize, MomentumOptimizer, AdamOptimizer and etc...
- if op.inputs and op.inputs.has_key("Param") \
- and op.inputs.has_key("LearningRate"):
+ if "Param" in op.input_names and \
+ "LearningRate" in op.input_names:
return True
return False
@@ -530,12 +521,12 @@ class DistributeTranspiler:
param_names = [
p.name for p in self.param_grad_ep_mapping[endpoint]["params"]
]
- if op.inputs["Param"].name in param_names:
+ if op.input("Param") in param_names:
return True
else:
for n in param_names:
- param = op.inputs["Param"].name
- if same_or_split_var(n, param) and n != op.inputs["Param"].name:
+ param = op.input("Param")[0]
+ if same_or_split_var(n, param) and n != param:
return True
return False
return False
@@ -564,7 +555,6 @@ class DistributeTranspiler:
persistable=True,
dtype=v.dtype,
shape=v.shape)
# step6
optimize_block = pserver_program.create_block(0)
# step 6.1
......
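For reference, a self-contained sketch of the op-connectivity rule described in the comment inside the diff: two operators are considered connected when one op's output argument names intersect the other op's input argument names. In the real code the name lists come from op.desc.input_arg_names() and op.desc.output_arg_names(); the helper name and the literal variable names below are made up for illustration.

def ops_connected(op1_inputs, op1_outputs, op2_inputs, op2_outputs):
    # Connected if op1 feeds op2 or op2 feeds op1 (any shared variable name).
    return bool(set(op1_outputs) & set(op2_inputs) or
                set(op1_inputs) & set(op2_outputs))

# A sum op producing a merged gradient feeds an SGD op that consumes it:
print(ops_connected(
    ["fc_0.w_0@GRAD.trainer_0", "fc_0.w_0@GRAD.trainer_1"],  # sum inputs
    ["fc_0.w_0@GRAD"],                                        # sum outputs
    ["fc_0.w_0@GRAD", "fc_0.w_0", "learning_rate_0"],         # sgd inputs
    ["fc_0.w_0"]))                                            # sgd outputs -> True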