提交 0689e2a5 编写于 作者: C cxxly 提交者: Xiaoxu Chen

fix _strip_grad_suffix_ bugs when input patten is 'x@GRAD@RENAME'

上级 5e5481d8
...@@ -464,9 +464,10 @@ def _strip_grad_suffix_(name): ...@@ -464,9 +464,10 @@ def _strip_grad_suffix_(name):
x@GRAD@GRAD ==> x x@GRAD@GRAD ==> x
y@GRAD@RENAME@1 ==> y y@GRAD@RENAME@1 ==> y
z@GRAD_slice_0@GRAD ==> z@GRAD_slice_0 z@GRAD_slice_0@GRAD ==> z@GRAD_slice_0
grad/grad/z@GRAD@RENAME@block0@1@GRAD ==> z
""" """
pos = re.search(f'{core.grad_var_suffix()}$', name) or re.search( pos = re.search(f'{core.grad_var_suffix()}+@', name) or re.search(
f'{core.grad_var_suffix()}@', name f'{core.grad_var_suffix()}$', name
) )
new_name = name[: pos.start()] if pos is not None else name new_name = name[: pos.start()] if pos is not None else name
new_pos = name.rfind('grad/') new_pos = name.rfind('grad/')
......
...@@ -19,6 +19,7 @@ import numpy as np ...@@ -19,6 +19,7 @@ import numpy as np
import paddle import paddle
import paddle.nn.functional as F import paddle.nn.functional as F
from paddle import fluid, static from paddle import fluid, static
from paddle.fluid import backward
class BackwardNet: class BackwardNet:
...@@ -449,6 +450,19 @@ class TestBackwardUninitializedVariable(unittest.TestCase): ...@@ -449,6 +450,19 @@ class TestBackwardUninitializedVariable(unittest.TestCase):
print(out) print(out)
class TestStripGradSuffix(unittest.TestCase):
def test_strip_grad_suffix(self):
cases = (
('x@GRAD', 'x'),
('x@GRAD@GRAD', 'x'),
('x@GRAD@RENAME@1', 'x'),
('x@GRAD_slice_0@GRAD', 'x@GRAD_slice_0'),
('grad/grad/x@GRAD@RENAME@block0@1@GRAD', 'x'),
)
for input_, desired in cases:
self.assertEqual(backward._strip_grad_suffix_(input_), desired)
if __name__ == '__main__': if __name__ == '__main__':
paddle.enable_static() paddle.enable_static()
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册