未验证 提交 b67715a4 编写于 作者: K kangguangli 提交者: GitHub

[NewIR] add decorator for dy2st test with new ir (#55840)

* add decorator for new_ir_test

* fix bug and only test in ci-coverage

* fix bug and only test in ci-coverage

* fix

* fix bugs

* fix

* fix
上级 697c712f
...@@ -3962,6 +3962,7 @@ function main() { ...@@ -3962,6 +3962,7 @@ function main() {
check_coverage_build check_coverage_build
;; ;;
gpu_cicheck_coverage) gpu_cicheck_coverage)
export FLAGS_NEW_IR_DY2ST_TEST=True
parallel_test parallel_test
check_coverage check_coverage
;; ;;
......
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
import collections import collections
import inspect import inspect
import os
import textwrap import textwrap
import threading import threading
import warnings import warnings
...@@ -192,6 +193,7 @@ class CacheKey: ...@@ -192,6 +193,7 @@ class CacheKey:
'class_instance', 'class_instance',
'kwargs', 'kwargs',
'_spec_names_id', '_spec_names_id',
'_new_ir_flags',
] ]
def __init__( def __init__(
...@@ -221,6 +223,9 @@ class CacheKey: ...@@ -221,6 +223,9 @@ class CacheKey:
self._spec_names_id = _hash_spec_names( self._spec_names_id = _hash_spec_names(
input_args_with_spec, input_kwargs_with_spec input_args_with_spec, input_kwargs_with_spec
) )
self._new_ir_flags = os.environ.get(
'FLAGS_enable_new_ir_in_executor', None
)
@classmethod @classmethod
def from_func_and_args(cls, function_spec, args, kwargs, class_instance): def from_func_and_args(cls, function_spec, args, kwargs, class_instance):
...@@ -264,6 +269,7 @@ class CacheKey: ...@@ -264,6 +269,7 @@ class CacheKey:
self.class_instance, self.class_instance,
with_hook, with_hook,
is_train, is_train,
self._new_ir_flags,
) )
) )
......
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from functools import wraps
import numpy as np
from paddle import set_flags, static
from paddle.fluid import core
def test_with_new_ir(func):
@wraps(func)
def impl(*args, **kwargs):
ir_outs = None
with static.scope_guard(static.Scope()):
with static.program_guard(static.Program()):
try:
new_ir_flag = 'FLAGS_enable_new_ir_in_executor'
os.environ[new_ir_flag] = 'True'
set_flags({new_ir_flag: True})
ir_outs = func(*args, **kwargs)
finally:
del os.environ[new_ir_flag]
set_flags({new_ir_flag: False})
return ir_outs
return impl
def test_and_compare_with_new_ir(need_check_output: bool = True):
def decorator(func):
@wraps(func)
def impl(*args, **kwargs):
outs = func(*args, **kwargs)
if core._is_bwd_prim_enabled() or core._is_fwd_prim_enabled():
return outs
# only run in CI-Coverage
if os.environ.get('FLAGS_NEW_IR_DY2ST_TEST', None) is None:
return outs
ir_outs = test_with_new_ir(func)(*args, **kwargs)
if not need_check_output:
return outs
for i in range(len(outs)):
np.testing.assert_array_equal(
outs[i],
ir_outs[i],
err_msg='Dy2St Unittest Check ('
+ func.__name__
+ ') has diff '
+ '\nExpect '
+ str(outs[i])
+ '\n'
+ 'But Got'
+ str(ir_outs[i]),
)
return outs
return impl
return decorator
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
import unittest import unittest
import numpy as np import numpy as np
from dy2st_test_utils import test_and_compare_with_new_ir
from paddle import fluid from paddle import fluid
from paddle.jit.api import to_static from paddle.jit.api import to_static
...@@ -88,6 +89,7 @@ class TestCastBase(unittest.TestCase): ...@@ -88,6 +89,7 @@ class TestCastBase(unittest.TestCase):
res = self.func(self.input) res = self.func(self.input)
return res return res
@test_and_compare_with_new_ir(False)
def test_cast_result(self): def test_cast_result(self):
res = self.do_test().numpy() res = self.do_test().numpy()
self.assertTrue( self.assertTrue(
...@@ -154,6 +156,7 @@ class TestMixCast(TestCastBase): ...@@ -154,6 +156,7 @@ class TestMixCast(TestCastBase):
def set_func(self): def set_func(self):
self.func = test_mix_cast self.func = test_mix_cast
@test_and_compare_with_new_ir(False)
def test_cast_result(self): def test_cast_result(self):
res = self.do_test().numpy() res = self.do_test().numpy()
self.assertTrue( self.assertTrue(
...@@ -186,6 +189,7 @@ class TestNotVarCast(TestCastBase): ...@@ -186,6 +189,7 @@ class TestNotVarCast(TestCastBase):
def set_func(self): def set_func(self):
self.func = test_not_var_cast self.func = test_not_var_cast
@test_and_compare_with_new_ir(False)
def test_cast_result(self): def test_cast_result(self):
res = self.do_test() res = self.do_test()
self.assertTrue(type(res) == int, msg='The casted dtype is not int.') self.assertTrue(type(res) == int, msg='The casted dtype is not int.')
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册