未验证 提交 d380ad06 编写于 作者: H helinwang 提交者: GitHub

Merge pull request #7837 from helinwang/transpiler_fix

Transpiler: fix pserver crash due to split var name check.
...@@ -33,6 +33,10 @@ class VarBlock: ...@@ -33,6 +33,10 @@ class VarBlock:
return "%s:%d:%d" % (self.varname, self.offset, self.size) return "%s:%d:%d" % (self.varname, self.offset, self.size)
def same_or_split_var(p_name, var_name):
return p_name == var_name or p_name.startswith(var_name + ".block")
def split_dense_variable(var_list, def split_dense_variable(var_list,
pserver_count, pserver_count,
min_block_size=1024, min_block_size=1024,
...@@ -303,8 +307,8 @@ class DistributeTranspiler: ...@@ -303,8 +307,8 @@ class DistributeTranspiler:
return True return True
else: else:
for n in param_names: for n in param_names:
if n.startswith(op.inputs["Param"].name+".block") and \ if same_or_split_var(n, op.inputs[
n != op.inputs["Param"].name: "Param"].name) and n != op.inputs["Param"].name:
return True return True
return False return False
else: else:
...@@ -335,7 +339,7 @@ class DistributeTranspiler: ...@@ -335,7 +339,7 @@ class DistributeTranspiler:
if key == "Grad": if key == "Grad":
grad_block = None grad_block = None
for g in self.param_grad_ep_mapping[endpoint]["grads"]: for g in self.param_grad_ep_mapping[endpoint]["grads"]:
if g.name.startswith(var.name): if same_or_split_var(g.name, var.name):
grad_block = g grad_block = g
break break
if not grad_block: if not grad_block:
...@@ -365,7 +369,7 @@ class DistributeTranspiler: ...@@ -365,7 +369,7 @@ class DistributeTranspiler:
# param is already created on global program # param is already created on global program
param_block = None param_block = None
for p in self.param_grad_ep_mapping[endpoint]["params"]: for p in self.param_grad_ep_mapping[endpoint]["params"]:
if p.name.startswith(var.name): if same_or_split_var(p.name, var.name):
param_block = p param_block = p
break break
if not param_block: if not param_block:
...@@ -502,7 +506,7 @@ class DistributeTranspiler: ...@@ -502,7 +506,7 @@ class DistributeTranspiler:
def _get_splited_name_and_shape(varname): def _get_splited_name_and_shape(varname):
for idx, splited_param in enumerate(params): for idx, splited_param in enumerate(params):
pname = splited_param.name pname = splited_param.name
if pname.startswith(varname) and varname != pname: if same_or_split_var(pname, varname) and varname != pname:
return pname, splited_param.shape return pname, splited_param.shape
return "", [] return "", []
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册