From f4ae373770c4482d3cd26f1be8ee4d97ac3794f7 Mon Sep 17 00:00:00 2001 From: WangZhen <23097963+0x45f@users.noreply.github.com> Date: Thu, 13 Apr 2023 14:01:54 +0800 Subject: [PATCH] [Dy2St]Fix _param_grad_names when grad name likes 'param@GRAD@GRAD' (#52821) * Fix _param_grad_names when like 'param@GRAD@GRAD' --- python/paddle/jit/dy2static/utils.py | 25 ++++++++++++++----- .../dygraph_to_static/test_gradname_parse.py | 0 2 files changed, 19 insertions(+), 6 deletions(-) rename {python/paddle/fluid/tests/unittests => test}/dygraph_to_static/test_gradname_parse.py (100%) diff --git a/python/paddle/jit/dy2static/utils.py b/python/paddle/jit/dy2static/utils.py index 28c8c739f2e..bd90f6089fe 100644 --- a/python/paddle/jit/dy2static/utils.py +++ b/python/paddle/jit/dy2static/utils.py @@ -87,6 +87,9 @@ WHILE_BODY_PREFIX = 'while_body' FOR_CONDITION_PREFIX = 'for_loop_condition' FOR_BODY_PREFIX = 'for_loop_body' +GRAD_PREFIX = 'grad/' +GRAD_SUFFIX = '@GRAD' + NO_SHAPE_VAR_TYPE = [ core.VarDesc.VarType.READER, core.VarDesc.VarType.STEP_SCOPES, @@ -1463,18 +1466,28 @@ def _param_grad_names(program_desc, params): # the param grad name can be set correctly in the run_program. for param in params: candidate = [] - suffix = param.name + '@GRAD' for var in program_desc.block(0).all_vars(): var_name = var.name() - if var_name.endswith(suffix): - prefix_count = var_name.count('grad/') - if 'grad/' * prefix_count + suffix == var_name: + if param.name not in var_name: + continue + suf_count = var_name.count(GRAD_SUFFIX) + if suf_count > 0: + suffix = param.name + GRAD_SUFFIX * suf_count + pre_count = var_name.count(GRAD_PREFIX) + if GRAD_PREFIX * pre_count + suffix == var_name: candidate.append(var_name) if candidate: - names.append(max(candidate, key=lambda name: name.count('grad/'))) + names.append( + max( + candidate, + key=lambda name: name.count(GRAD_PREFIX) + if GRAD_PREFIX in name + else name.count(GRAD_SUFFIX), + ) + ) else: - names.append(suffix) + names.append(param.name + GRAD_SUFFIX) return names diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_gradname_parse.py b/test/dygraph_to_static/test_gradname_parse.py similarity index 100% rename from python/paddle/fluid/tests/unittests/dygraph_to_static/test_gradname_parse.py rename to test/dygraph_to_static/test_gradname_parse.py -- GitLab