未验证 提交 110f769d 编写于 作者: W WangZhen 提交者: GitHub

[SOT]Fix ast_only_test in paddle ci (#56081)

上级 0e2a4d57
......@@ -422,7 +422,8 @@ class StaticFunction:
# first encouter the bound function of layer and cache it.
new_static_layer = self._clone()
if (
self._dygraph_function.__name__
isinstance(instance, layers.Layer)
and self._dygraph_function.__name__
not in instance._original_funcs.keys()
):
instance._original_funcs[
......
......@@ -74,7 +74,7 @@ set_tests_properties(test_yolov3 PROPERTIES TIMEOUT 900 LABELS
set_tests_properties(test_mobile_net PROPERTIES TIMEOUT 120)
set_tests_properties(test_seq2seq PROPERTIES TIMEOUT 150)
set_tests_properties(test_cycle_gan PROPERTIES TIMEOUT 150)
set_tests_properties(test_bert PROPERTIES TIMEOUT 120)
set_tests_properties(test_bert PROPERTIES TIMEOUT 180)
set_tests_properties(test_bert_with_stride PROPERTIES TIMEOUT 120)
set_tests_properties(test_basic_api_transformation PROPERTIES TIMEOUT 120)
set_tests_properties(test_reinforcement_learning PROPERTIES TIMEOUT 120)
......
......@@ -35,7 +35,7 @@ def enable_fallback_guard(enable):
def to_ast(func):
"""
convet run fall_back to ast
convert run fall_back to ast
"""
def impl(*args, **kwargs):
......@@ -47,7 +47,7 @@ def to_ast(func):
def to_sot(func):
"""
convet run fall_back to ast
convert run fall_back to ast
"""
enable_sot = os.environ.get("ENABLE_SOT", "False") == "True"
......@@ -65,11 +65,12 @@ def dy2static_unittest(cls):
"""
dy2static unittest must be decorated to each Dy2static Unittests.
run both in Fallback and Ast mode.
Usage like:
@dy2static_unittest
class TestA (unittest.TestCase):
...
Examples:
>>> @dy2static_unittest
... class TestA(unittest.TestCase):
... ...
"""
for key in dir(cls):
if key.startswith("test"):
......@@ -84,16 +85,18 @@ def dy2static_unittest(cls):
def ast_only_test(func):
"""
run this test function in ast only mode.
Usage:
class TestA (unittest.TestCase):
@ast_only_test
def test_ast_only(self):
pass
Examples:
>>> @dy2static_unittest
... class TestA(unittest.TestCase):
... @ast_only_test
... def test_ast_only(self):
... pass
"""
def impl(*args, **kwargs):
if os.environ.get("ENABLE_FALL_BACK", "True") == "False":
if os.environ.get("ENABLE_FALL_BACK", "False") == "False":
func(*args, **kwargs)
return impl
......@@ -102,16 +105,18 @@ def ast_only_test(func):
def sot_only_test(func):
"""
run this test function in ast only mode.
Usage:
class TestA (unittest.TestCase):
@ast_only_test
def test_ast_only(self):
pass
Examples:
>>> @dy2static_unittest
... class TestA(unittest.TestCase):
... @sot_only_test
... def test_sot_only(self):
... pass
"""
def impl(*args, **kwargs):
if os.environ.get("ENABLE_FALL_BACK", "True") == "True":
if os.environ.get("ENABLE_FALL_BACK", "False") == "True":
func(*args, **kwargs)
return impl
......
......@@ -29,7 +29,7 @@ class TestEvalFrame(unittest.TestCase):
def test_eval_frame(self):
if version_info.major != 3 or (
version_info.minor <= 8 or version_info.minor >= 11
version_info.minor <= 8 or version_info.minor >= 12
):
# print("skip test_eval_frame, current only support 3.8 - 3.10")
return
......
......@@ -15,7 +15,7 @@
import unittest
import numpy as np
from dygraph_to_static_util import ast_only_test
from dygraph_to_static_util import ast_only_test, dy2static_unittest
import paddle
from paddle.jit.dy2static.program_translator import StaticFunction
......@@ -85,6 +85,7 @@ class TestRollBackPlainFunction(unittest.TestCase):
np.testing.assert_array_equal(st_out.numpy(), dy_out.numpy())
@dy2static_unittest
class TestRollBackNet(unittest.TestCase):
def setUp(self):
paddle.set_device("cpu")
......@@ -135,7 +136,9 @@ class FuncRollback(paddle.nn.Layer):
return x + 2
@dy2static_unittest
class TestRollBackNotForward(unittest.TestCase):
@ast_only_test
def test_rollback(self):
x = paddle.zeros([2, 2])
net = FuncRollback()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册