未验证 提交 f4ae3737 编写于 作者: W WangZhen 提交者: GitHub

[Dy2St]Fix _param_grad_names when grad name likes 'param@GRAD@GRAD' (#52821)

* Fix _param_grad_names when like 'param@GRAD@GRAD'
上级 0b98d1aa
......@@ -87,6 +87,9 @@ WHILE_BODY_PREFIX = 'while_body'
FOR_CONDITION_PREFIX = 'for_loop_condition'
FOR_BODY_PREFIX = 'for_loop_body'
GRAD_PREFIX = 'grad/'
GRAD_SUFFIX = '@GRAD'
NO_SHAPE_VAR_TYPE = [
core.VarDesc.VarType.READER,
core.VarDesc.VarType.STEP_SCOPES,
......@@ -1463,18 +1466,28 @@ def _param_grad_names(program_desc, params):
# the param grad name can be set correctly in the run_program.
for param in params:
candidate = []
suffix = param.name + '@GRAD'
for var in program_desc.block(0).all_vars():
var_name = var.name()
if var_name.endswith(suffix):
prefix_count = var_name.count('grad/')
if 'grad/' * prefix_count + suffix == var_name:
if param.name not in var_name:
continue
suf_count = var_name.count(GRAD_SUFFIX)
if suf_count > 0:
suffix = param.name + GRAD_SUFFIX * suf_count
pre_count = var_name.count(GRAD_PREFIX)
if GRAD_PREFIX * pre_count + suffix == var_name:
candidate.append(var_name)
if candidate:
names.append(max(candidate, key=lambda name: name.count('grad/')))
names.append(
max(
candidate,
key=lambda name: name.count(GRAD_PREFIX)
if GRAD_PREFIX in name
else name.count(GRAD_SUFFIX),
)
)
else:
names.append(suffix)
names.append(param.name + GRAD_SUFFIX)
return names
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册