From 39b0fdc3afca31f89224534f921f5fdcb48778d2 Mon Sep 17 00:00:00 2001 From: Helin Wang Date: Wed, 24 Jan 2018 15:14:54 -0800 Subject: [PATCH] Transpiler: fix pserver crash due to split var name check. In notest_dist_label_semantic_roles.py, "emb" is matched with "embedding_1.w_0", but they are two irrevalent vars. Fixes: https://github.com/PaddlePaddle/Paddle/issues/7701 --- python/paddle/v2/fluid/distribute_transpiler.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/python/paddle/v2/fluid/distribute_transpiler.py b/python/paddle/v2/fluid/distribute_transpiler.py index abcad899bf..934eba73b8 100644 --- a/python/paddle/v2/fluid/distribute_transpiler.py +++ b/python/paddle/v2/fluid/distribute_transpiler.py @@ -33,6 +33,10 @@ class VarBlock: return "%s:%d:%d" % (self.varname, self.offset, self.size) +def same_or_split_var(p_name, var_name): + return p_name == var_name or p_name.startswith(var_name + ".block") + + def split_dense_variable(var_list, pserver_count, min_block_size=1024, @@ -303,8 +307,8 @@ class DistributeTranspiler: return True else: for n in param_names: - if n.startswith(op.inputs["Param"].name+".block") and \ - n != op.inputs["Param"].name: + if same_or_split_var(n, op.inputs[ + "Param"].name) and n != op.inputs["Param"].name: return True return False else: @@ -335,7 +339,7 @@ class DistributeTranspiler: if key == "Grad": grad_block = None for g in self.param_grad_ep_mapping[endpoint]["grads"]: - if g.name.startswith(var.name): + if same_or_split_var(g.name, var.name): grad_block = g break if not grad_block: @@ -365,7 +369,7 @@ class DistributeTranspiler: # param is already created on global program param_block = None for p in self.param_grad_ep_mapping[endpoint]["params"]: - if p.name.startswith(var.name): + if same_or_split_var(p.name, var.name): param_block = p break if not param_block: @@ -502,7 +506,7 @@ class DistributeTranspiler: def _get_splited_name_and_shape(varname): for idx, splited_param in enumerate(params): pname = splited_param.name - if pname.startswith(varname) and varname != pname: + if same_or_split_var(pname, varname) and varname != pname: return pname, splited_param.shape return "", [] -- GitLab