diff --git a/python/paddle/fluid/backward.py b/python/paddle/fluid/backward.py index ff57966eefbfb3eac97019d9b858e8da01987303..51fe6294a5077abfe495ea00f956100e948a1fed 100755 --- a/python/paddle/fluid/backward.py +++ b/python/paddle/fluid/backward.py @@ -464,9 +464,10 @@ def _strip_grad_suffix_(name): x@GRAD@GRAD ==> x y@GRAD@RENAME@1 ==> y z@GRAD_slice_0@GRAD ==> z@GRAD_slice_0 + grad/grad/z@GRAD@RENAME@block0@1@GRAD ==> z """ - pos = re.search(f'{core.grad_var_suffix()}$', name) or re.search( - f'{core.grad_var_suffix()}@', name + pos = re.search(f'{core.grad_var_suffix()}+@', name) or re.search( + f'{core.grad_var_suffix()}$', name ) new_name = name[: pos.start()] if pos is not None else name new_pos = name.rfind('grad/') diff --git a/python/paddle/fluid/tests/unittests/test_backward.py b/python/paddle/fluid/tests/unittests/test_backward.py index b2abf39fecac50edf025ceaf309f12384e7f433d..031664e4cb35d6d4339d967cdb94ca9afe16fa14 100644 --- a/python/paddle/fluid/tests/unittests/test_backward.py +++ b/python/paddle/fluid/tests/unittests/test_backward.py @@ -19,6 +19,7 @@ import numpy as np import paddle import paddle.nn.functional as F from paddle import fluid, static +from paddle.fluid import backward class BackwardNet: @@ -449,6 +450,19 @@ class TestBackwardUninitializedVariable(unittest.TestCase): print(out) +class TestStripGradSuffix(unittest.TestCase): + def test_strip_grad_suffix(self): + cases = ( + ('x@GRAD', 'x'), + ('x@GRAD@GRAD', 'x'), + ('x@GRAD@RENAME@1', 'x'), + ('x@GRAD_slice_0@GRAD', 'x@GRAD_slice_0'), + ('grad/grad/x@GRAD@RENAME@block0@1@GRAD', 'x'), + ) + for input_, desired in cases: + self.assertEqual(backward._strip_grad_suffix_(input_), desired) + + if __name__ == '__main__': paddle.enable_static() unittest.main()