未验证 提交 110f769d 编写于 作者: W WangZhen 提交者: GitHub

[SOT]Fix ast_only_test in paddle ci (#56081)

上级 0e2a4d57
...@@ -422,7 +422,8 @@ class StaticFunction: ...@@ -422,7 +422,8 @@ class StaticFunction:
# first encouter the bound function of layer and cache it. # first encouter the bound function of layer and cache it.
new_static_layer = self._clone() new_static_layer = self._clone()
if ( if (
self._dygraph_function.__name__ isinstance(instance, layers.Layer)
and self._dygraph_function.__name__
not in instance._original_funcs.keys() not in instance._original_funcs.keys()
): ):
instance._original_funcs[ instance._original_funcs[
......
...@@ -74,7 +74,7 @@ set_tests_properties(test_yolov3 PROPERTIES TIMEOUT 900 LABELS ...@@ -74,7 +74,7 @@ set_tests_properties(test_yolov3 PROPERTIES TIMEOUT 900 LABELS
set_tests_properties(test_mobile_net PROPERTIES TIMEOUT 120) set_tests_properties(test_mobile_net PROPERTIES TIMEOUT 120)
set_tests_properties(test_seq2seq PROPERTIES TIMEOUT 150) set_tests_properties(test_seq2seq PROPERTIES TIMEOUT 150)
set_tests_properties(test_cycle_gan PROPERTIES TIMEOUT 150) set_tests_properties(test_cycle_gan PROPERTIES TIMEOUT 150)
set_tests_properties(test_bert PROPERTIES TIMEOUT 120) set_tests_properties(test_bert PROPERTIES TIMEOUT 180)
set_tests_properties(test_bert_with_stride PROPERTIES TIMEOUT 120) set_tests_properties(test_bert_with_stride PROPERTIES TIMEOUT 120)
set_tests_properties(test_basic_api_transformation PROPERTIES TIMEOUT 120) set_tests_properties(test_basic_api_transformation PROPERTIES TIMEOUT 120)
set_tests_properties(test_reinforcement_learning PROPERTIES TIMEOUT 120) set_tests_properties(test_reinforcement_learning PROPERTIES TIMEOUT 120)
......
...@@ -35,7 +35,7 @@ def enable_fallback_guard(enable): ...@@ -35,7 +35,7 @@ def enable_fallback_guard(enable):
def to_ast(func): def to_ast(func):
""" """
convet run fall_back to ast convert run fall_back to ast
""" """
def impl(*args, **kwargs): def impl(*args, **kwargs):
...@@ -47,7 +47,7 @@ def to_ast(func): ...@@ -47,7 +47,7 @@ def to_ast(func):
def to_sot(func): def to_sot(func):
""" """
convet run fall_back to ast convert run fall_back to ast
""" """
enable_sot = os.environ.get("ENABLE_SOT", "False") == "True" enable_sot = os.environ.get("ENABLE_SOT", "False") == "True"
...@@ -65,11 +65,12 @@ def dy2static_unittest(cls): ...@@ -65,11 +65,12 @@ def dy2static_unittest(cls):
""" """
dy2static unittest must be decorated to each Dy2static Unittests. dy2static unittest must be decorated to each Dy2static Unittests.
run both in Fallback and Ast mode. run both in Fallback and Ast mode.
Usage like:
@dy2static_unittest Examples:
class TestA (unittest.TestCase):
... >>> @dy2static_unittest
... class TestA(unittest.TestCase):
... ...
""" """
for key in dir(cls): for key in dir(cls):
if key.startswith("test"): if key.startswith("test"):
...@@ -84,16 +85,18 @@ def dy2static_unittest(cls): ...@@ -84,16 +85,18 @@ def dy2static_unittest(cls):
def ast_only_test(func): def ast_only_test(func):
""" """
run this test function in ast only mode. run this test function in ast only mode.
Usage:
class TestA (unittest.TestCase): Examples:
@ast_only_test
def test_ast_only(self): >>> @dy2static_unittest
pass ... class TestA(unittest.TestCase):
... @ast_only_test
... def test_ast_only(self):
... pass
""" """
def impl(*args, **kwargs): def impl(*args, **kwargs):
if os.environ.get("ENABLE_FALL_BACK", "True") == "False": if os.environ.get("ENABLE_FALL_BACK", "False") == "False":
func(*args, **kwargs) func(*args, **kwargs)
return impl return impl
...@@ -102,16 +105,18 @@ def ast_only_test(func): ...@@ -102,16 +105,18 @@ def ast_only_test(func):
def sot_only_test(func): def sot_only_test(func):
""" """
run this test function in ast only mode. run this test function in ast only mode.
Usage:
class TestA (unittest.TestCase): Examples:
@ast_only_test
def test_ast_only(self): >>> @dy2static_unittest
pass ... class TestA(unittest.TestCase):
... @sot_only_test
... def test_sot_only(self):
... pass
""" """
def impl(*args, **kwargs): def impl(*args, **kwargs):
if os.environ.get("ENABLE_FALL_BACK", "True") == "True": if os.environ.get("ENABLE_FALL_BACK", "False") == "True":
func(*args, **kwargs) func(*args, **kwargs)
return impl return impl
......
...@@ -29,7 +29,7 @@ class TestEvalFrame(unittest.TestCase): ...@@ -29,7 +29,7 @@ class TestEvalFrame(unittest.TestCase):
def test_eval_frame(self): def test_eval_frame(self):
if version_info.major != 3 or ( if version_info.major != 3 or (
version_info.minor <= 8 or version_info.minor >= 11 version_info.minor <= 8 or version_info.minor >= 12
): ):
# print("skip test_eval_frame, current only support 3.8 - 3.10") # print("skip test_eval_frame, current only support 3.8 - 3.10")
return return
......
...@@ -15,7 +15,7 @@ ...@@ -15,7 +15,7 @@
import unittest import unittest
import numpy as np import numpy as np
from dygraph_to_static_util import ast_only_test from dygraph_to_static_util import ast_only_test, dy2static_unittest
import paddle import paddle
from paddle.jit.dy2static.program_translator import StaticFunction from paddle.jit.dy2static.program_translator import StaticFunction
...@@ -85,6 +85,7 @@ class TestRollBackPlainFunction(unittest.TestCase): ...@@ -85,6 +85,7 @@ class TestRollBackPlainFunction(unittest.TestCase):
np.testing.assert_array_equal(st_out.numpy(), dy_out.numpy()) np.testing.assert_array_equal(st_out.numpy(), dy_out.numpy())
@dy2static_unittest
class TestRollBackNet(unittest.TestCase): class TestRollBackNet(unittest.TestCase):
def setUp(self): def setUp(self):
paddle.set_device("cpu") paddle.set_device("cpu")
...@@ -135,7 +136,9 @@ class FuncRollback(paddle.nn.Layer): ...@@ -135,7 +136,9 @@ class FuncRollback(paddle.nn.Layer):
return x + 2 return x + 2
@dy2static_unittest
class TestRollBackNotForward(unittest.TestCase): class TestRollBackNotForward(unittest.TestCase):
@ast_only_test
def test_rollback(self): def test_rollback(self):
x = paddle.zeros([2, 2]) x = paddle.zeros([2, 2])
net = FuncRollback() net = FuncRollback()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册