Commit 50a02adf authored by T typhoonzero

transpile program ok

Parent 9c0b1cf1
@@ -114,12 +114,15 @@ class DistributeTranspiler:
# step3
send_inputs = []
send_outputs = []
for _, splited in grad_var_mapping.iteritems():
send_inputs.extend(splited)
for b in grad_blocks: # append by order
varname, block_id, _ = b.split(":")
send_inputs.append(grad_var_mapping[varname][int(block_id)])
param_var_mapping = self._create_vars_from_blocklist(program,
param_blocks)
for _, splited in param_var_mapping.iteritems():
send_outputs.extend(splited)
for b in param_blocks:
varname, block_id, _ = b.split(":")
send_outputs.append(param_var_mapping[varname][int(block_id)])
# let send_op know which endpoint to send which var, eplist is of the same
# order of send_inputs.
eplist = split_method(send_inputs, pserver_endpoints)
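
Both loops above walk the "varname:block_id:size"-style strings in grad_blocks / param_blocks so that send_inputs and send_outputs are built in one fixed order, and eplist, returned by split_method in that same order, tells send_op which pserver endpoint receives which split variable. A minimal sketch of that pairing, assuming a simple round-robin split_method and purely illustrative variable names and block sizes:

# Illustration only: round_robin is a hypothetical stand-in for split_method;
# the third field of each block string is taken here to be the block size.
def round_robin(varlist, endpoints):
    eplist = []
    for i, _ in enumerate(varlist):
        eplist.append(endpoints[i % len(endpoints)])
    return eplist

grad_blocks = ["fc_0.w_0@GRAD:0:500", "fc_0.w_0@GRAD:1:500", "fc_0.b_0@GRAD:0:10"]
pserver_endpoints = ["127.0.0.1:6174", "127.0.0.1:6175"]
eplist = round_robin(grad_blocks, pserver_endpoints)
# eplist == ["127.0.0.1:6174", "127.0.0.1:6175", "127.0.0.1:6174"]
# The only property relied on is that index i of eplist always corresponds to
# index i of send_inputs, since both are built in grad_blocks order.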
@@ -243,8 +246,37 @@ class DistributeTranspiler:
var_list.append(var_each)
return var_list
def _append_pserver_ops(self, opt_op, endpoint):
def _get_optimizer_input_shape(self, op_type, varkey, orig_shape,
param_shape):
"""
Returns the shape for optimizer inputs that need to be reshaped when
Param and Grad is splited to multiple servers.
"""
# HACK(typhoonzero): Should use functions of corresponding optimizer in
# optimizer.py to get the shape, do not bind this in the transpiler.
if op_type == "adam":
if varkey in ["Moment1", "Moment2"]:
return param_shape
elif op_type == "adagrad":
if varkey == "Moment":
return param_shape
elif op_type == "adamax":
if varkey in ["Moment", "InfNorm"]:
return param_shape
elif op_type == "momentum":
if varkey == "Velocity":
return param_shape
elif op_type == "":
if varkey == "Moment":
return param_shape
elif op_type == "sgd":
pass
return orig_shape
def _append_pserver_ops(self, program, opt_op, endpoint):
new_inputs = dict()
# update param/grad shape first, then other inputs like
# moment can use the updated shape
for key, var in opt_op.inputs.iteritems():
if key == "Grad":
grad_block = None
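
_get_optimizer_input_shape encodes which optimizer inputs are per-element accumulators: those must be re-created with the shape of the parameter block this pserver owns, while scalar-like inputs keep their original shape. A compact restatement of that table, for illustration only (the transpiler itself uses the method above):

# Per-element accumulators follow the (split) parameter's shape; everything
# else keeps its original shape.
def optimizer_input_shape(op_type, varkey, orig_shape, param_shape):
    elementwise_keys = {
        "adam": ["Moment1", "Moment2"],
        "adagrad": ["Moment"],
        "adamax": ["Moment", "InfNorm"],
        "momentum": ["Velocity"],
    }
    if varkey in elementwise_keys.get(op_type, []):
        return param_shape
    return orig_shape

# A (1024, 10) parameter split into two (512, 10) row blocks:
print(optimizer_input_shape("adam", "Moment1", (1024, 10), (512, 10)))  # (512, 10)
print(optimizer_input_shape("adam", "Beta1Pow", (1, ), (512, 10)))      # (1,)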
@@ -256,7 +288,7 @@ class DistributeTranspiler:
# do not append this op if current endpoint
# is not dealing with this grad block
return
merged_var = optimize_sub_program.global_block().create_var(
merged_var = program.global_block().create_var(
name=grad_block.name,
persistable=grad_block.persistable,
dtype=grad_block.dtype,
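
The early return above is what keeps each pserver from optimizing gradient blocks it was never assigned. A sketch of that ownership check, assuming the split gradient variables are recorded per endpoint in self.param_grad_ep_mapping under a "grads" key (only the "params" key is visible in this diff, in get_pserver_program below) and matched by name:

def find_grad_block(param_grad_ep_mapping, endpoint, grad_name):
    # Return the gradient block assigned to this endpoint, or None when this
    # pserver does not own it, in which case _append_pserver_ops returns early
    # and the optimize op is simply not appended for this endpoint.
    for g in param_grad_ep_mapping[endpoint].get("grads", []):
        if g.name == grad_name:
            return g
    return None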
@@ -264,13 +296,12 @@ class DistributeTranspiler:
# append merging ops if trainers > 1
if self.trainers > 1:
vars2merge = self._create_var_for_trainers(
optimize_sub_program.global_block(), grad_block,
self.trainers)
optimize_sub_program.global_block().append_op(
program.global_block(), grad_block, self.trainers)
program.global_block().append_op(
type="sum",
inputs={"X": vars2merge},
outputs={"Out": merged_var})
optimize_sub_program.global_block().append_op(
program.global_block().append_op(
type="scale",
inputs={"X": merged_var},
outputs={"Out": merged_var},
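
When trainers > 1, the "sum" op followed by the "scale" op averages the per-trainer copies of a gradient block before it reaches the optimizer. A numeric illustration, assuming _create_var_for_trainers yields one variable per trainer and that the scale attribute (not visible in this hunk) is 1.0 / trainers:

import numpy as np

trainers = 3
# Hypothetical per-trainer copies of one gradient block.
grads = [np.full((2, 2), t + 1.0) for t in range(trainers)]
merged = sum(grads)                    # effect of the "sum" op
averaged = merged * (1.0 / trainers)   # effect of "scale" with scale=1/trainers
# averaged == np.full((2, 2), 2.0): the element-wise mean of the three copies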
@@ -285,37 +316,45 @@ class DistributeTranspiler:
break
if not param_block:
return
tmpvar = optimize_sub_program.global_block().create_var(
tmpvar = program.global_block().create_var(
name=param_block.name,
persistable=param_block.persistable,
dtype=param_block.dtype,
shape=param_block.shape)
new_inputs[key] = tmpvar
else:
tmpvar = optimize_sub_program.global_block().create_var(
name=var.name,
persistable=var.persistable,
dtype=var.dtype,
shape=var.shape)
new_inputs[key] = tmpvar
for key, var in opt_op.inputs.iteritems():
if key in ["Param", "Grad"]:
continue
# update accumulator variable shape
param_shape = new_inputs["Param"].shape
new_shape = self._get_optimizer_input_shape(opt_op.type, key,
var.shape, param_shape)
print("var, new shape", key, var.name, new_shape)
tmpvar = program.global_block().create_var(
name=var.name,
persistable=var.persistable,
dtype=var.dtype,
shape=new_shape)
new_inputs[key] = tmpvar
# FIXME: change outputs ParamOut
optimize_sub_program.global_block().append_op(
program.global_block().append_op(
type=opt_op.type,
inputs=new_inputs,
outputs=opt_op.outputs,
attrs=opt_op.attrs)
def _append_pserver_non_opt_ops(self, opt_op):
def _append_pserver_non_opt_ops(self, program, opt_op):
for _, var in opt_op.inputs.iteritems():
optimize_sub_program.global_block().create_var(
program.global_block().create_var(
name=var.name,
persistable=var.persistable,
dtype=var.dtype,
shape=var.shape)
optimize_sub_program.global_block().append_op(
program.global_block().append_op(
type=opt_op.type,
inputs=new_inputs,
inputs=opt_op.inputs,
outputs=opt_op.outputs,
attrs=opt_op.attrs)
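
Taken together, the two loops in _append_pserver_ops leave new_inputs pointing at this endpoint's parameter block, the merged gradient for that block, and accumulators reshaped to match, while scalar-like inputs are carried over unchanged. A worked example for an adam op whose (1024, 10) parameter was split into two (512, 10) blocks; the input names follow Paddle's adam operator, and the shapes are hypothetical:

new_input_shapes = {
    "Param": (512, 10),     # this endpoint's parameter block (first loop)
    "Grad": (512, 10),      # merged gradient block for the same rows
    "Moment1": (512, 10),   # reshaped by _get_optimizer_input_shape
    "Moment2": (512, 10),   # reshaped by _get_optimizer_input_shape
    "Beta1Pow": (1, ),      # keeps its original shape
    "Beta2Pow": (1, ),      # keeps its original shape
    "LearningRate": (1, ),  # keeps its original shape
}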
@@ -331,15 +370,15 @@ class DistributeTranspiler:
# step5
pserver_program = Program()
for v in self.param_grad_ep_mapping[endpoint]["params"]:
self._clone_param(pserver_program.global_block(), v)
self._clone_var(pserver_program.global_block(), v)
# step6
optimize_sub_program = Program()
for opt_op in optimize_ops:
if opt_ops.inputs.has_key("Grad"):
if opt_op.inputs.has_key("Grad"):
# append optimize_op
self._append_pserver_ops(opt_op, endpoint)
self._append_pserver_ops(optimize_sub_program, opt_op, endpoint)
else:
self._append_pserver_non_opt_ops(opt_op)
self._append_pserver_non_opt_ops(optimize_sub_program, opt_op)
pserver_program.global_block().append_op(
type="recv",
......