提交 6fa56b9d 编写于 作者: T typhoonzero

left startup program bug

上级 50a02adf
......@@ -56,8 +56,6 @@ def split_dense_variable(var_list,
(block_id) * block_size))
block = VarBlock(var.name, block_id, curr_block_size)
blocks.append(str(block))
print("$$ splited var: ", var.name, var.shape, split_count, len(blocks),
block_size)
return blocks
......@@ -126,7 +124,7 @@ class DistributeTranspiler:
# let send_op know which endpoint to send which var, eplist is of the same
# order of send_inputs.
eplist = split_method(send_inputs, pserver_endpoints)
# create mapping of endpoint -> var to create pserver side program
# create mapping of endpoint -> splited var to create pserver side program
self.param_grad_ep_mapping = dict()
for i, ep in enumerate(eplist):
param = send_outputs[i]
......@@ -142,7 +140,6 @@ class DistributeTranspiler:
outputs={"Out": send_outputs},
attrs={"endpoints": pserver_endpoints,
"epmap": eplist})
# step4
for varname, splited_var in param_var_mapping.iteritems():
if len(splited_var) <= 1:
......@@ -187,21 +184,6 @@ class DistributeTranspiler:
var_mapping[varname].append(var)
return var_mapping
def _clone_param(self, block, v):
assert isinstance(v, Parameter)
new_p = Parameter(
block=block,
shape=v.shape,
dtype=v.dtype,
type=v.type,
lod_level=v.lod_level,
stop_gradient=v.stop_gradient,
trainable=v.trainable,
optimize_attr=v.optimize_attr,
regularizer=v.regularizer,
name=v.name)
block.vars[new_p.name] = new_p
def _clone_var(self, block, var):
assert isinstance(var, Variable)
return block.create_var(
......@@ -210,7 +192,9 @@ class DistributeTranspiler:
dtype=var.dtype,
type=var.type,
lod_level=var.lod_level,
persistable=var.persistable)
# HACK: let all param in pserver persistable so child
# program in recv can get them
persistable=True)
def _append_split_op(self, program, gradblocks):
var_mapping = self._create_vars_from_blocklist(program, gradblocks)
......@@ -318,9 +302,10 @@ class DistributeTranspiler:
return
tmpvar = program.global_block().create_var(
name=param_block.name,
persistable=param_block.persistable,
persistable=True,
dtype=param_block.dtype,
shape=param_block.shape)
new_inputs[key] = tmpvar
for key, var in opt_op.inputs.iteritems():
......@@ -330,7 +315,6 @@ class DistributeTranspiler:
param_shape = new_inputs["Param"].shape
new_shape = self._get_optimizer_input_shape(opt_op.type, key,
var.shape, param_shape)
print("var, new shape", key, var.name, new_shape)
tmpvar = program.global_block().create_var(
name=var.name,
persistable=var.persistable,
......@@ -338,7 +322,8 @@ class DistributeTranspiler:
shape=new_shape)
new_inputs[key] = tmpvar
# FIXME: change outputs ParamOut
# change outputs ParamOut variable
opt_op.outputs["ParamOut"] = new_inputs["Param"]
program.global_block().append_op(
type=opt_op.type,
inputs=new_inputs,
......@@ -380,6 +365,7 @@ class DistributeTranspiler:
else:
self._append_pserver_non_opt_ops(optimize_sub_program, opt_op)
print("####", optimize_sub_program)
pserver_program.global_block().append_op(
type="recv",
inputs={"RX": self.param_grad_ep_mapping[endpoint]["grads"]
......@@ -400,3 +386,53 @@ class DistributeTranspiler:
})
pserver_program.sync_with_cpp()
return pserver_program
def get_startup_program(self, endpoint):
"""
Get startup program for current parameter server.
Modify operator input variables if there are variables that
was splited to several blocks.
"""
s_prog = Program()
orig_s_prog = framework.default_startup_program()
params = self.param_grad_ep_mapping[endpoint]["params"]
def _get_splited_name_and_shape(varname):
for idx, splited_param in enumerate(params):
pname = splited_param.name
if pname.startswith(varname) and varname != pname:
return pname, splited_param.shape
return "", []
# 1. create vars
created_var_map = dict()
for var in params:
print("%%%% append var", var.name, var.shape)
tmpvar = s_prog.global_block().create_var(
name=var.name,
persistable=True,
dtype=var.dtype,
shape=var.shape)
created_var_map[var.name] = tmpvar
# 2. rename op outputs
for op in orig_s_prog.global_block().ops:
new_outputs = dict()
for key, var in op.outputs.iteritems():
newname, _ = _get_splited_name_and_shape(var.name)
if newname:
new_outputs[key] = created_var_map[newname]
else:
new_outputs[key] = var
# do not append startup op if var is not on this pserver
var_on_pserver = False
for _, var in new_outputs.iteritems():
if var.name in created_var_map:
var_on_pserver = True
if var_on_pserver:
s_prog.global_block().append_op(
type=op.type,
inputs=op.inputs,
outputs=new_outputs,
attrs=op.attrs)
return s_prog
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册