diff --git a/python/paddle/jit/dy2static/program_translator.py b/python/paddle/jit/dy2static/program_translator.py index 6ecdc8cb26accb1ee0c770205e27b1a1ac5ca64c..ea5c159c579cbacfca35ca71120e5e1c6ab4c375 100644 --- a/python/paddle/jit/dy2static/program_translator.py +++ b/python/paddle/jit/dy2static/program_translator.py @@ -422,7 +422,8 @@ class StaticFunction: # first encouter the bound function of layer and cache it. new_static_layer = self._clone() if ( - self._dygraph_function.__name__ + isinstance(instance, layers.Layer) + and self._dygraph_function.__name__ not in instance._original_funcs.keys() ): instance._original_funcs[ diff --git a/test/dygraph_to_static/CMakeLists.txt b/test/dygraph_to_static/CMakeLists.txt index 858ace95250b18e0e322f2f17a1512e13da9120b..61c47e1708d5aaa6064f16ac205ddf90c4f8b434 100644 --- a/test/dygraph_to_static/CMakeLists.txt +++ b/test/dygraph_to_static/CMakeLists.txt @@ -74,7 +74,7 @@ set_tests_properties(test_yolov3 PROPERTIES TIMEOUT 900 LABELS set_tests_properties(test_mobile_net PROPERTIES TIMEOUT 120) set_tests_properties(test_seq2seq PROPERTIES TIMEOUT 150) set_tests_properties(test_cycle_gan PROPERTIES TIMEOUT 150) -set_tests_properties(test_bert PROPERTIES TIMEOUT 120) +set_tests_properties(test_bert PROPERTIES TIMEOUT 180) set_tests_properties(test_bert_with_stride PROPERTIES TIMEOUT 120) set_tests_properties(test_basic_api_transformation PROPERTIES TIMEOUT 120) set_tests_properties(test_reinforcement_learning PROPERTIES TIMEOUT 120) diff --git a/test/dygraph_to_static/dygraph_to_static_util.py b/test/dygraph_to_static/dygraph_to_static_util.py index c1792b7486180e85c4ec87e24df625ef07cc1348..ca04f8d0c2046d39a488ed43a1a3dcdc7576f325 100644 --- a/test/dygraph_to_static/dygraph_to_static_util.py +++ b/test/dygraph_to_static/dygraph_to_static_util.py @@ -35,7 +35,7 @@ def enable_fallback_guard(enable): def to_ast(func): """ - convet run fall_back to ast + convert run fall_back to ast """ def impl(*args, **kwargs): @@ -47,7 +47,7 @@ def to_ast(func): def to_sot(func): """ - convet run fall_back to ast + convert run fall_back to ast """ enable_sot = os.environ.get("ENABLE_SOT", "False") == "True" @@ -65,11 +65,12 @@ def dy2static_unittest(cls): """ dy2static unittest must be decorated to each Dy2static Unittests. run both in Fallback and Ast mode. - Usage like: - @dy2static_unittest - class TestA (unittest.TestCase): - ... + Examples: + + >>> @dy2static_unittest + ... class TestA(unittest.TestCase): + ... ... """ for key in dir(cls): if key.startswith("test"): @@ -84,16 +85,18 @@ def dy2static_unittest(cls): def ast_only_test(func): """ run this test function in ast only mode. - Usage: - class TestA (unittest.TestCase): - @ast_only_test - def test_ast_only(self): - pass + Examples: + + >>> @dy2static_unittest + ... class TestA(unittest.TestCase): + ... @ast_only_test + ... def test_ast_only(self): + ... pass """ def impl(*args, **kwargs): - if os.environ.get("ENABLE_FALL_BACK", "True") == "False": + if os.environ.get("ENABLE_FALL_BACK", "False") == "False": func(*args, **kwargs) return impl @@ -102,16 +105,18 @@ def ast_only_test(func): def sot_only_test(func): """ run this test function in ast only mode. - Usage: - class TestA (unittest.TestCase): - @ast_only_test - def test_ast_only(self): - pass + Examples: + + >>> @dy2static_unittest + ... class TestA(unittest.TestCase): + ... @sot_only_test + ... def test_sot_only(self): + ... pass """ def impl(*args, **kwargs): - if os.environ.get("ENABLE_FALL_BACK", "True") == "True": + if os.environ.get("ENABLE_FALL_BACK", "False") == "True": func(*args, **kwargs) return impl diff --git a/test/dygraph_to_static/test_eval_frame.py b/test/dygraph_to_static/test_eval_frame.py index 8584f776bce408df9b1b4b4b56dfe967b2078b35..dfa5e04b44ffbb0edf6a95f1289ae3a4550a5fd2 100644 --- a/test/dygraph_to_static/test_eval_frame.py +++ b/test/dygraph_to_static/test_eval_frame.py @@ -29,7 +29,7 @@ class TestEvalFrame(unittest.TestCase): def test_eval_frame(self): if version_info.major != 3 or ( - version_info.minor <= 8 or version_info.minor >= 11 + version_info.minor <= 8 or version_info.minor >= 12 ): # print("skip test_eval_frame, current only support 3.8 - 3.10") return diff --git a/test/dygraph_to_static/test_rollback.py b/test/dygraph_to_static/test_rollback.py index 80b78d52344d6ad228888864d1870e057c36a0f1..882546be097c03fea6f9aeb0a48f59da4e7d0555 100644 --- a/test/dygraph_to_static/test_rollback.py +++ b/test/dygraph_to_static/test_rollback.py @@ -15,7 +15,7 @@ import unittest import numpy as np -from dygraph_to_static_util import ast_only_test +from dygraph_to_static_util import ast_only_test, dy2static_unittest import paddle from paddle.jit.dy2static.program_translator import StaticFunction @@ -85,6 +85,7 @@ class TestRollBackPlainFunction(unittest.TestCase): np.testing.assert_array_equal(st_out.numpy(), dy_out.numpy()) +@dy2static_unittest class TestRollBackNet(unittest.TestCase): def setUp(self): paddle.set_device("cpu") @@ -135,7 +136,9 @@ class FuncRollback(paddle.nn.Layer): return x + 2 +@dy2static_unittest class TestRollBackNotForward(unittest.TestCase): + @ast_only_test def test_rollback(self): x = paddle.zeros([2, 2]) net = FuncRollback()