diff --git a/python/paddle/jit/dy2static/utils.py b/python/paddle/jit/dy2static/utils.py index 28c8c739f2efca82401d937a414ae051a785fca0..bd90f6089fe95bad6901038a6f4823642eacbf69 100644 --- a/python/paddle/jit/dy2static/utils.py +++ b/python/paddle/jit/dy2static/utils.py @@ -87,6 +87,9 @@ WHILE_BODY_PREFIX = 'while_body' FOR_CONDITION_PREFIX = 'for_loop_condition' FOR_BODY_PREFIX = 'for_loop_body' +GRAD_PREFIX = 'grad/' +GRAD_SUFFIX = '@GRAD' + NO_SHAPE_VAR_TYPE = [ core.VarDesc.VarType.READER, core.VarDesc.VarType.STEP_SCOPES, @@ -1463,18 +1466,28 @@ def _param_grad_names(program_desc, params): # the param grad name can be set correctly in the run_program. for param in params: candidate = [] - suffix = param.name + '@GRAD' for var in program_desc.block(0).all_vars(): var_name = var.name() - if var_name.endswith(suffix): - prefix_count = var_name.count('grad/') - if 'grad/' * prefix_count + suffix == var_name: + if param.name not in var_name: + continue + suf_count = var_name.count(GRAD_SUFFIX) + if suf_count > 0: + suffix = param.name + GRAD_SUFFIX * suf_count + pre_count = var_name.count(GRAD_PREFIX) + if GRAD_PREFIX * pre_count + suffix == var_name: candidate.append(var_name) if candidate: - names.append(max(candidate, key=lambda name: name.count('grad/'))) + names.append( + max( + candidate, + key=lambda name: name.count(GRAD_PREFIX) + if GRAD_PREFIX in name + else name.count(GRAD_SUFFIX), + ) + ) else: - names.append(suffix) + names.append(param.name + GRAD_SUFFIX) return names diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_gradname_parse.py b/test/dygraph_to_static/test_gradname_parse.py similarity index 100% rename from python/paddle/fluid/tests/unittests/dygraph_to_static/test_gradname_parse.py rename to test/dygraph_to_static/test_gradname_parse.py