未验证 提交 b67715a4 编写于 作者: K kangguangli 提交者: GitHub

[NewIR] add decorator for dy2st test with new ir (#55840)

* add decorator for new_ir_test

* fix bug and only test in ci-coverage

* fix bug and only test in ci-coverage

* fix

* fix bugs

* fix

* fix
上级 697c712f
......@@ -3962,6 +3962,7 @@ function main() {
check_coverage_build
;;
gpu_cicheck_coverage)
export FLAGS_NEW_IR_DY2ST_TEST=True
parallel_test
check_coverage
;;
......
......@@ -14,6 +14,7 @@
import collections
import inspect
import os
import textwrap
import threading
import warnings
......@@ -192,6 +193,7 @@ class CacheKey:
'class_instance',
'kwargs',
'_spec_names_id',
'_new_ir_flags',
]
def __init__(
......@@ -221,6 +223,9 @@ class CacheKey:
self._spec_names_id = _hash_spec_names(
input_args_with_spec, input_kwargs_with_spec
)
self._new_ir_flags = os.environ.get(
'FLAGS_enable_new_ir_in_executor', None
)
@classmethod
def from_func_and_args(cls, function_spec, args, kwargs, class_instance):
......@@ -264,6 +269,7 @@ class CacheKey:
self.class_instance,
with_hook,
is_train,
self._new_ir_flags,
)
)
......
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from functools import wraps
import numpy as np
from paddle import set_flags, static
from paddle.fluid import core
def test_with_new_ir(func):
@wraps(func)
def impl(*args, **kwargs):
ir_outs = None
with static.scope_guard(static.Scope()):
with static.program_guard(static.Program()):
try:
new_ir_flag = 'FLAGS_enable_new_ir_in_executor'
os.environ[new_ir_flag] = 'True'
set_flags({new_ir_flag: True})
ir_outs = func(*args, **kwargs)
finally:
del os.environ[new_ir_flag]
set_flags({new_ir_flag: False})
return ir_outs
return impl
def test_and_compare_with_new_ir(need_check_output: bool = True):
def decorator(func):
@wraps(func)
def impl(*args, **kwargs):
outs = func(*args, **kwargs)
if core._is_bwd_prim_enabled() or core._is_fwd_prim_enabled():
return outs
# only run in CI-Coverage
if os.environ.get('FLAGS_NEW_IR_DY2ST_TEST', None) is None:
return outs
ir_outs = test_with_new_ir(func)(*args, **kwargs)
if not need_check_output:
return outs
for i in range(len(outs)):
np.testing.assert_array_equal(
outs[i],
ir_outs[i],
err_msg='Dy2St Unittest Check ('
+ func.__name__
+ ') has diff '
+ '\nExpect '
+ str(outs[i])
+ '\n'
+ 'But Got'
+ str(ir_outs[i]),
)
return outs
return impl
return decorator
......@@ -15,6 +15,7 @@
import unittest
import numpy as np
from dy2st_test_utils import test_and_compare_with_new_ir
from paddle import fluid
from paddle.jit.api import to_static
......@@ -88,6 +89,7 @@ class TestCastBase(unittest.TestCase):
res = self.func(self.input)
return res
@test_and_compare_with_new_ir(False)
def test_cast_result(self):
res = self.do_test().numpy()
self.assertTrue(
......@@ -154,6 +156,7 @@ class TestMixCast(TestCastBase):
def set_func(self):
self.func = test_mix_cast
@test_and_compare_with_new_ir(False)
def test_cast_result(self):
res = self.do_test().numpy()
self.assertTrue(
......@@ -186,6 +189,7 @@ class TestNotVarCast(TestCastBase):
def set_func(self):
self.func = test_not_var_cast
@test_and_compare_with_new_ir(False)
def test_cast_result(self):
res = self.do_test()
self.assertTrue(type(res) == int, msg='The casted dtype is not int.')
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册