未验证 提交 d380ad06 编写于 作者: H helinwang 提交者: GitHub

Merge pull request #7837 from helinwang/transpiler_fix

Transpiler: fix pserver crash due to split var name check.
......@@ -33,6 +33,10 @@ class VarBlock:
return "%s:%d:%d" % (self.varname, self.offset, self.size)
def same_or_split_var(p_name, var_name):
return p_name == var_name or p_name.startswith(var_name + ".block")
def split_dense_variable(var_list,
pserver_count,
min_block_size=1024,
......@@ -303,8 +307,8 @@ class DistributeTranspiler:
return True
else:
for n in param_names:
if n.startswith(op.inputs["Param"].name+".block") and \
n != op.inputs["Param"].name:
if same_or_split_var(n, op.inputs[
"Param"].name) and n != op.inputs["Param"].name:
return True
return False
else:
......@@ -335,7 +339,7 @@ class DistributeTranspiler:
if key == "Grad":
grad_block = None
for g in self.param_grad_ep_mapping[endpoint]["grads"]:
if g.name.startswith(var.name):
if same_or_split_var(g.name, var.name):
grad_block = g
break
if not grad_block:
......@@ -365,7 +369,7 @@ class DistributeTranspiler:
# param is already created on global program
param_block = None
for p in self.param_grad_ep_mapping[endpoint]["params"]:
if p.name.startswith(var.name):
if same_or_split_var(p.name, var.name):
param_block = p
break
if not param_block:
......@@ -502,7 +506,7 @@ class DistributeTranspiler:
def _get_splited_name_and_shape(varname):
for idx, splited_param in enumerate(params):
pname = splited_param.name
if pname.startswith(varname) and varname != pname:
if same_or_split_var(pname, varname) and varname != pname:
return pname, splited_param.shape
return "", []
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册