From 110f769de06ca63d60aa12231d5244b3d0719d4d Mon Sep 17 00:00:00 2001 From: WangZhen <23097963+0x45f@users.noreply.github.com> Date: Thu, 10 Aug 2023 14:21:28 +0800 Subject: [PATCH] [SOT]Fix ast_only_test in paddle ci (#56081) --- .../jit/dy2static/program_translator.py | 3 +- test/dygraph_to_static/CMakeLists.txt | 2 +- .../dygraph_to_static_util.py | 41 +++++++++++-------- test/dygraph_to_static/test_eval_frame.py | 2 +- test/dygraph_to_static/test_rollback.py | 5 ++- 5 files changed, 31 insertions(+), 22 deletions(-) diff --git a/python/paddle/jit/dy2static/program_translator.py b/python/paddle/jit/dy2static/program_translator.py index 6ecdc8cb26a..ea5c159c579 100644 --- a/python/paddle/jit/dy2static/program_translator.py +++ b/python/paddle/jit/dy2static/program_translator.py @@ -422,7 +422,8 @@ class StaticFunction: # first encouter the bound function of layer and cache it. new_static_layer = self._clone() if ( - self._dygraph_function.__name__ + isinstance(instance, layers.Layer) + and self._dygraph_function.__name__ not in instance._original_funcs.keys() ): instance._original_funcs[ diff --git a/test/dygraph_to_static/CMakeLists.txt b/test/dygraph_to_static/CMakeLists.txt index 858ace95250..61c47e1708d 100644 --- a/test/dygraph_to_static/CMakeLists.txt +++ b/test/dygraph_to_static/CMakeLists.txt @@ -74,7 +74,7 @@ set_tests_properties(test_yolov3 PROPERTIES TIMEOUT 900 LABELS set_tests_properties(test_mobile_net PROPERTIES TIMEOUT 120) set_tests_properties(test_seq2seq PROPERTIES TIMEOUT 150) set_tests_properties(test_cycle_gan PROPERTIES TIMEOUT 150) -set_tests_properties(test_bert PROPERTIES TIMEOUT 120) +set_tests_properties(test_bert PROPERTIES TIMEOUT 180) set_tests_properties(test_bert_with_stride PROPERTIES TIMEOUT 120) set_tests_properties(test_basic_api_transformation PROPERTIES TIMEOUT 120) set_tests_properties(test_reinforcement_learning PROPERTIES TIMEOUT 120) diff --git a/test/dygraph_to_static/dygraph_to_static_util.py b/test/dygraph_to_static/dygraph_to_static_util.py index c1792b74861..ca04f8d0c20 100644 --- a/test/dygraph_to_static/dygraph_to_static_util.py +++ b/test/dygraph_to_static/dygraph_to_static_util.py @@ -35,7 +35,7 @@ def enable_fallback_guard(enable): def to_ast(func): """ - convet run fall_back to ast + convert run fall_back to ast """ def impl(*args, **kwargs): @@ -47,7 +47,7 @@ def to_ast(func): def to_sot(func): """ - convet run fall_back to ast + convert run fall_back to ast """ enable_sot = os.environ.get("ENABLE_SOT", "False") == "True" @@ -65,11 +65,12 @@ def dy2static_unittest(cls): """ dy2static unittest must be decorated to each Dy2static Unittests. run both in Fallback and Ast mode. - Usage like: - @dy2static_unittest - class TestA (unittest.TestCase): - ... + Examples: + + >>> @dy2static_unittest + ... class TestA(unittest.TestCase): + ... ... """ for key in dir(cls): if key.startswith("test"): @@ -84,16 +85,18 @@ def dy2static_unittest(cls): def ast_only_test(func): """ run this test function in ast only mode. - Usage: - class TestA (unittest.TestCase): - @ast_only_test - def test_ast_only(self): - pass + Examples: + + >>> @dy2static_unittest + ... class TestA(unittest.TestCase): + ... @ast_only_test + ... def test_ast_only(self): + ... pass """ def impl(*args, **kwargs): - if os.environ.get("ENABLE_FALL_BACK", "True") == "False": + if os.environ.get("ENABLE_FALL_BACK", "False") == "False": func(*args, **kwargs) return impl @@ -102,16 +105,18 @@ def ast_only_test(func): def sot_only_test(func): """ run this test function in ast only mode. - Usage: - class TestA (unittest.TestCase): - @ast_only_test - def test_ast_only(self): - pass + Examples: + + >>> @dy2static_unittest + ... class TestA(unittest.TestCase): + ... @sot_only_test + ... def test_sot_only(self): + ... pass """ def impl(*args, **kwargs): - if os.environ.get("ENABLE_FALL_BACK", "True") == "True": + if os.environ.get("ENABLE_FALL_BACK", "False") == "True": func(*args, **kwargs) return impl diff --git a/test/dygraph_to_static/test_eval_frame.py b/test/dygraph_to_static/test_eval_frame.py index 8584f776bce..dfa5e04b44f 100644 --- a/test/dygraph_to_static/test_eval_frame.py +++ b/test/dygraph_to_static/test_eval_frame.py @@ -29,7 +29,7 @@ class TestEvalFrame(unittest.TestCase): def test_eval_frame(self): if version_info.major != 3 or ( - version_info.minor <= 8 or version_info.minor >= 11 + version_info.minor <= 8 or version_info.minor >= 12 ): # print("skip test_eval_frame, current only support 3.8 - 3.10") return diff --git a/test/dygraph_to_static/test_rollback.py b/test/dygraph_to_static/test_rollback.py index 80b78d52344..882546be097 100644 --- a/test/dygraph_to_static/test_rollback.py +++ b/test/dygraph_to_static/test_rollback.py @@ -15,7 +15,7 @@ import unittest import numpy as np -from dygraph_to_static_util import ast_only_test +from dygraph_to_static_util import ast_only_test, dy2static_unittest import paddle from paddle.jit.dy2static.program_translator import StaticFunction @@ -85,6 +85,7 @@ class TestRollBackPlainFunction(unittest.TestCase): np.testing.assert_array_equal(st_out.numpy(), dy_out.numpy()) +@dy2static_unittest class TestRollBackNet(unittest.TestCase): def setUp(self): paddle.set_device("cpu") @@ -135,7 +136,9 @@ class FuncRollback(paddle.nn.Layer): return x + 2 +@dy2static_unittest class TestRollBackNotForward(unittest.TestCase): + @ast_only_test def test_rollback(self): x = paddle.zeros([2, 2]) net = FuncRollback() -- GitLab