Commit 50a02adf, authored by typhoonzero

transpile program ok

Parent 9c0b1cf1
@@ -114,12 +114,15 @@ class DistributeTranspiler:
         # step3
         send_inputs = []
         send_outputs = []
-        for _, splited in grad_var_mapping.iteritems():
-            send_inputs.extend(splited)
+        for b in grad_blocks:  # append by order
+            varname, block_id, _ = b.split(":")
+            send_inputs.append(grad_var_mapping[varname][int(block_id)])
         param_var_mapping = self._create_vars_from_blocklist(program,
                                                              param_blocks)
-        for _, splited in param_var_mapping.iteritems():
-            send_outputs.extend(splited)
+        for b in param_blocks:
+            varname, block_id, _ = b.split(":")
+            send_outputs.append(param_var_mapping[varname][int(block_id)])
         # let send_op know which endpoint to send which var, eplist is of the same
         # order of send_inputs.
         eplist = split_method(send_inputs, pserver_endpoints)
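
The rewritten loops above build send_inputs and send_outputs in the deterministic order given by grad_blocks and param_blocks, instead of whatever order dict iteration happens to produce, so that eplist lines up with send_inputs index by index. A minimal sketch of that ordering logic, assuming block strings of the form "varname:block_id:size" and using plain strings in place of the framework's variable objects:

    # Sketch only: the block strings and grad_var_mapping contents are made-up
    # placeholders, not values taken from the real transpiler.
    grad_blocks = ["fc_w@GRAD:0:2048", "fc_w@GRAD:1:2048", "fc_b@GRAD:0:64"]
    grad_var_mapping = {
        "fc_w@GRAD": ["fc_w@GRAD.block0", "fc_w@GRAD.block1"],
        "fc_b@GRAD": ["fc_b@GRAD.block0"],
    }

    send_inputs = []
    for b in grad_blocks:  # append by order, exactly as in the hunk above
        varname, block_id, _ = b.split(":")
        send_inputs.append(grad_var_mapping[varname][int(block_id)])

    print(send_inputs)
    # ['fc_w@GRAD.block0', 'fc_w@GRAD.block1', 'fc_b@GRAD.block0']

Because send_inputs now follows the grad_blocks order, the eplist returned by split_method maps each split variable to its pserver endpoint by position.
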
@@ -243,8 +246,37 @@ class DistributeTranspiler:
                 var_list.append(var_each)
         return var_list

-    def _append_pserver_ops(self, opt_op, endpoint):
+    def _get_optimizer_input_shape(self, op_type, varkey, orig_shape,
+                                   param_shape):
+        """
+        Returns the shape for optimizer inputs that need to be reshaped when
+        Param and Grad is splited to multiple servers.
+        """
+        # HACK(typhoonzero): Should use functions of corresponding optimizer in
+        # optimizer.py to get the shape, do not bind this in the transpiler.
+        if op_type == "adam":
+            if varkey in ["Moment1", "Moment2"]:
+                return param_shape
+        elif op_type == "adagrad":
+            if varkey == "Moment":
+                return param_shape
+        elif op_type == "adamax":
+            if varkey in ["Moment", "InfNorm"]:
+                return param_shape
+        elif op_type == "momentum":
+            if varkey == "Velocity":
+                return param_shape
+        elif op_type == "":
+            if varkey == "Moment":
+                return param_shape
+        elif op_type == "sgd":
+            pass
+        return orig_shape
+
+    def _append_pserver_ops(self, program, opt_op, endpoint):
         new_inputs = dict()
+        # update param/grad shape first, then other inputs like
+        # moment can use the updated shape
         for key, var in opt_op.inputs.iteritems():
             if key == "Grad":
                 grad_block = None
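
The new _get_optimizer_input_shape helper captures one consequence of splitting a parameter across pservers: per-element optimizer state (Adam moments, momentum velocity, and so on) must take the shape of the local parameter shard, while scalar state keeps its original shape. A small sketch under assumed shapes; the standalone function below mirrors only the adam branch of the method above:

    # Why Adam's Moment1/Moment2 follow the split Param shape. All shapes here
    # are illustrative assumptions.
    def optimizer_input_shape(op_type, varkey, orig_shape, param_shape):
        # mirrors the "adam" branch of _get_optimizer_input_shape above
        if op_type == "adam" and varkey in ["Moment1", "Moment2"]:
            return param_shape
        return orig_shape

    full_shape = (1000, 64)    # parameter shape before splitting
    shard_shape = (500, 64)    # the half owned by this pserver

    # per-element accumulators are resized to the local shard ...
    assert optimizer_input_shape("adam", "Moment1", full_shape, shard_shape) == (500, 64)
    # ... while scalar state such as Beta1Pow keeps its original shape
    assert optimizer_input_shape("adam", "Beta1Pow", (1,), shard_shape) == (1,)
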
@@ -256,7 +288,7 @@ class DistributeTranspiler:
                     # do not append this op if current endpoint
                     # is not dealing with this grad block
                     return
-                merged_var = optimize_sub_program.global_block().create_var(
+                merged_var = program.global_block().create_var(
                     name=grad_block.name,
                     persistable=grad_block.persistable,
                     dtype=grad_block.dtype,
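
Note the early return above: a pserver program only receives the optimize op for gradient blocks that this endpoint actually owns. A sketch of that ownership test, assuming the check matches grad blocks by name prefix; param_grad_ep_mapping, the endpoints, and the variable names below are all hypothetical:

    # Sketch of the "skip ops for shards this endpoint does not own" decision.
    class FakeVar(object):
        def __init__(self, name):
            self.name = name

    param_grad_ep_mapping = {
        "127.0.0.1:6174": {"grads": [FakeVar("fc_w@GRAD.block0")]},
        "127.0.0.1:6175": {"grads": [FakeVar("fc_w@GRAD.block1")]},
    }

    def owns_grad(endpoint, grad_name):
        # assumed name-prefix test that decides whether to hit the early return
        return any(g.name.startswith(grad_name)
                   for g in param_grad_ep_mapping[endpoint]["grads"])

    assert owns_grad("127.0.0.1:6174", "fc_w@GRAD")
    assert not owns_grad("127.0.0.1:6174", "fc_b@GRAD")  # early-return case
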
@@ -264,13 +296,12 @@ class DistributeTranspiler:
                 # append merging ops if trainers > 1
                 if self.trainers > 1:
                     vars2merge = self._create_var_for_trainers(
-                        optimize_sub_program.global_block(), grad_block,
-                        self.trainers)
-                    optimize_sub_program.global_block().append_op(
+                        program.global_block(), grad_block, self.trainers)
+                    program.global_block().append_op(
                         type="sum",
                         inputs={"X": vars2merge},
                         outputs={"Out": merged_var})
-                    optimize_sub_program.global_block().append_op(
+                    program.global_block().append_op(
                         type="scale",
                         inputs={"X": merged_var},
                         outputs={"Out": merged_var},
@@ -285,37 +316,45 @@ class DistributeTranspiler:
                         break
                 if not param_block:
                     return
-                tmpvar = optimize_sub_program.global_block().create_var(
+                tmpvar = program.global_block().create_var(
                     name=param_block.name,
                     persistable=param_block.persistable,
                     dtype=param_block.dtype,
                     shape=param_block.shape)
                 new_inputs[key] = tmpvar
-            else:
-                tmpvar = optimize_sub_program.global_block().create_var(
-                    name=var.name,
-                    persistable=var.persistable,
-                    dtype=var.dtype,
-                    shape=var.shape)
-                new_inputs[key] = tmpvar
+
+        for key, var in opt_op.inputs.iteritems():
+            if key in ["Param", "Grad"]:
+                continue
+            # update accumulator variable shape
+            param_shape = new_inputs["Param"].shape
+            new_shape = self._get_optimizer_input_shape(opt_op.type, key,
+                                                        var.shape, param_shape)
+            print("var, new shape", key, var.name, new_shape)
+            tmpvar = program.global_block().create_var(
+                name=var.name,
+                persistable=var.persistable,
+                dtype=var.dtype,
+                shape=new_shape)
+            new_inputs[key] = tmpvar

         # FIXME: change outputs ParamOut
-        optimize_sub_program.global_block().append_op(
+        program.global_block().append_op(
             type=opt_op.type,
             inputs=new_inputs,
             outputs=opt_op.outputs,
             attrs=opt_op.attrs)

-    def _append_pserver_non_opt_ops(self, opt_op):
+    def _append_pserver_non_opt_ops(self, program, opt_op):
         for _, var in opt_op.inputs.iteritems():
-            optimize_sub_program.global_block().create_var(
+            program.global_block().create_var(
                 name=var.name,
                 persistable=var.persistable,
                 dtype=var.dtype,
                 shape=var.shape)
-        optimize_sub_program.global_block().append_op(
+        program.global_block().append_op(
             type=opt_op.type,
-            inputs=new_inputs,
+            inputs=opt_op.inputs,
             outputs=opt_op.outputs,
             attrs=opt_op.attrs)
@@ -331,15 +370,15 @@ class DistributeTranspiler:
         # step5
         pserver_program = Program()
         for v in self.param_grad_ep_mapping[endpoint]["params"]:
-            self._clone_param(pserver_program.global_block(), v)
+            self._clone_var(pserver_program.global_block(), v)
         # step6
         optimize_sub_program = Program()
         for opt_op in optimize_ops:
-            if opt_ops.inputs.has_key("Grad"):
+            if opt_op.inputs.has_key("Grad"):
                 # append optimize_op
-                self._append_pserver_ops(opt_op, endpoint)
+                self._append_pserver_ops(optimize_sub_program, opt_op, endpoint)
             else:
-                self._append_pserver_non_opt_ops(opt_op)
+                self._append_pserver_non_opt_ops(optimize_sub_program, opt_op)
         pserver_program.global_block().append_op(
             type="recv",
...
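
Taken together, the pserver-side path now clones the parameter shards this endpoint owns, then walks optimize_ops and dispatches on whether an op consumes a Grad input: optimize ops go through _append_pserver_ops (which rewrites shapes for the local shard), anything else is copied into optimize_sub_program as-is by _append_pserver_non_opt_ops, and the resulting sub-program is attached to the recv op shown above. A sketch of that dispatch rule with mock op objects (not real framework operators):

    # The dispatch rule: an op is treated as an optimize op iff it has a
    # "Grad" input. FakeOp and the sample ops below are stand-ins.
    class FakeOp(object):
        def __init__(self, op_type, inputs):
            self.type = op_type
            self.inputs = inputs

    optimize_ops = [
        FakeOp("fill_constant", {}),  # e.g. learning-rate setup, no Grad input
        FakeOp("sgd", {"Param": "fc_w", "Grad": "fc_w@GRAD", "LearningRate": "lr"}),
    ]

    for op in optimize_ops:
        if "Grad" in op.inputs:  # the diff uses opt_op.inputs.has_key("Grad")
            print("-> _append_pserver_ops:", op.type)
        else:
            print("-> _append_pserver_non_opt_ops:", op.type)
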