未验证 提交 a2382986 编写于 作者: L lidanqing 提交者: GitHub

Skip some conv2d_int8 tests in windows (#30128)

上级 a60893f6
...@@ -16,7 +16,7 @@ from __future__ import print_function ...@@ -16,7 +16,7 @@ from __future__ import print_function
import unittest import unittest
import numpy as np import numpy as np
import os
import paddle.fluid.core as core import paddle.fluid.core as core
from paddle.fluid.tests.unittests.op_test import OpTest from paddle.fluid.tests.unittests.op_test import OpTest
from paddle.fluid.tests.unittests.test_conv2d_op import conv2d_forward_naive, TestConv2DOp from paddle.fluid.tests.unittests.test_conv2d_op import conv2d_forward_naive, TestConv2DOp
...@@ -28,6 +28,8 @@ def conv2d_forward_refer(input, filter, group, conv_param): ...@@ -28,6 +28,8 @@ def conv2d_forward_refer(input, filter, group, conv_param):
return out return out
@unittest.skipIf(not core.supports_bfloat16(),
"place does not support BF16 evaluation")
class TestConv2DInt8Op(TestConv2DOp): class TestConv2DInt8Op(TestConv2DOp):
def setUp(self): def setUp(self):
self.op_type = "conv2d" self.op_type = "conv2d"
...@@ -289,43 +291,31 @@ def init_data_type_with_fusion(self, input_dt, fuse_activation, fuse_residual): ...@@ -289,43 +291,31 @@ def init_data_type_with_fusion(self, input_dt, fuse_activation, fuse_residual):
def create_test_int8_class(parent): def create_test_int8_class(parent):
#--------------------test conv2d s8 in and u8 out-------------------- #--------------------test conv2d s8 in and u8 out--------------------
class TestS8U8Case(parent): class TestS8U8Case(parent):
def init_data_type(self): def init_data_type(self):
init_data_type_with_fusion(self, np.int8, "relu", False) init_data_type_with_fusion(self, np.int8, "relu", False)
#--------------------test conv2d s8 in and s8 out-------------------- #--------------------test conv2d s8 in and s8 out--------------------
class TestS8S8Case(parent): class TestS8S8Case(parent):
def init_data_type(self): def init_data_type(self):
init_data_type_with_fusion(self, np.int8, "", False) init_data_type_with_fusion(self, np.int8, "", False)
#--------------------test conv2d u8 in and s8 out-------------------- #--------------------test conv2d u8 in and s8 out--------------------
class TestU8S8Case(parent): class TestU8S8Case(parent):
def init_data_type(self): def init_data_type(self):
init_data_type_with_fusion(self, np.uint8, "", False) init_data_type_with_fusion(self, np.uint8, "", False)
#--------------------test conv2d u8 in and u8 out without residual fuse-------------------- #--------------------test conv2d u8 in and u8 out without residual fuse--------------------
class TestU8U8Case(parent): class TestU8U8Case(parent):
def init_data_type(self): def init_data_type(self):
init_data_type_with_fusion(self, np.uint8, "relu", False) init_data_type_with_fusion(self, np.uint8, "relu", False)
#--------------------test conv2d s8 in and u8 out with residual fuse--------------------
class TestS8U8ResCase(parent):
def init_data_type(self):
init_data_type_with_fusion(self, np.int8, "relu", True)
#--------------------test conv2d s8 in and s8 out with residual fuse-------------------- #--------------------test conv2d s8 in and s8 out with residual fuse--------------------
class TestS8S8ResCase(parent): class TestS8S8ResCase(parent):
def init_data_type(self): def init_data_type(self):
init_data_type_with_fusion(self, np.int8, "", True) init_data_type_with_fusion(self, np.int8, "", True)
#--------------------test conv2d u8 in and s8 out with residual fuse-------------------- #--------------------test conv2d u8 in and s8 out with residual fuse--------------------
class TestU8S8ResCase(parent): class TestU8S8ResCase(parent):
def init_data_type(self): def init_data_type(self):
init_data_type_with_fusion(self, np.uint8, "", True) init_data_type_with_fusion(self, np.uint8, "", True)
...@@ -334,8 +324,7 @@ def create_test_int8_class(parent): ...@@ -334,8 +324,7 @@ def create_test_int8_class(parent):
cls_name_s8s8 = "{0}_relu_{1}_residual_0".format(parent.__name__, "0") cls_name_s8s8 = "{0}_relu_{1}_residual_0".format(parent.__name__, "0")
cls_name_u8s8 = "{0}_relu_{1}_residual_0".format(parent.__name__, "0") cls_name_u8s8 = "{0}_relu_{1}_residual_0".format(parent.__name__, "0")
cls_name_u8u8 = "{0}_relu_{1}_residual_0".format(parent.__name__, "1") cls_name_u8u8 = "{0}_relu_{1}_residual_0".format(parent.__name__, "1")
cls_name_s8u8_re_1 = "{0}_relu_{1}_residual_{2}".format(parent.__name__,
"1", "1")
cls_name_s8s8_re_1 = "{0}_relu_{1}_residual_{2}".format(parent.__name__, cls_name_s8s8_re_1 = "{0}_relu_{1}_residual_{2}".format(parent.__name__,
"0", "1") "0", "1")
cls_name_u8s8_re_1 = "{0}_relu_{1}_residual_{2}".format(parent.__name__, cls_name_u8s8_re_1 = "{0}_relu_{1}_residual_{2}".format(parent.__name__,
...@@ -344,17 +333,27 @@ def create_test_int8_class(parent): ...@@ -344,17 +333,27 @@ def create_test_int8_class(parent):
TestS8S8Case.__name__ = cls_name_s8s8 TestS8S8Case.__name__ = cls_name_s8s8
TestU8S8Case.__name__ = cls_name_u8s8 TestU8S8Case.__name__ = cls_name_u8s8
TestU8U8Case.__name__ = cls_name_u8u8 TestU8U8Case.__name__ = cls_name_u8u8
TestS8U8ResCase.__name__ = cls_name_s8u8_re_1
TestS8S8ResCase.__name__ = cls_name_s8s8_re_1 TestS8S8ResCase.__name__ = cls_name_s8s8_re_1
TestU8S8ResCase.__name__ = cls_name_u8s8_re_1 TestU8S8ResCase.__name__ = cls_name_u8s8_re_1
globals()[cls_name_s8u8] = TestS8U8Case globals()[cls_name_s8u8] = TestS8U8Case
globals()[cls_name_s8s8] = TestS8S8Case globals()[cls_name_s8s8] = TestS8S8Case
globals()[cls_name_u8s8] = TestU8S8Case globals()[cls_name_u8s8] = TestU8S8Case
globals()[cls_name_u8u8] = TestU8U8Case globals()[cls_name_u8u8] = TestU8U8Case
globals()[cls_name_s8u8_re_1] = TestS8U8ResCase
globals()[cls_name_s8s8_re_1] = TestS8S8ResCase globals()[cls_name_s8s8_re_1] = TestS8S8ResCase
globals()[cls_name_u8s8_re_1] = TestU8S8ResCase globals()[cls_name_u8s8_re_1] = TestU8S8ResCase
if os.name != 'nt':
#--------------------test conv2d s8 in and u8 out with residual fuse--------------------
class TestS8U8ResCase(parent):
def init_data_type(self):
init_data_type_with_fusion(self, np.int8, "relu", True)
cls_name_s8u8_re_1 = "{0}_relu_{1}_residual_{2}".format(parent.__name__,
"1", "1")
TestS8U8ResCase.__name__ = cls_name_s8u8_re_1
globals()[cls_name_s8u8_re_1] = TestS8U8ResCase
create_test_int8_class(TestConv2DInt8Op) create_test_int8_class(TestConv2DInt8Op)
create_test_int8_class(TestWithPad) create_test_int8_class(TestWithPad)
...@@ -387,4 +386,6 @@ class TestConv2DOp_Valid_INT_MKLDNN(TestConv2DOp_AsyPadding_INT_MKLDNN): ...@@ -387,4 +386,6 @@ class TestConv2DOp_Valid_INT_MKLDNN(TestConv2DOp_AsyPadding_INT_MKLDNN):
if __name__ == '__main__': if __name__ == '__main__':
from paddle import enable_static
enable_static()
unittest.main() unittest.main()
...@@ -82,7 +82,6 @@ diable_wingpu_test="^test_gradient_clip$|\ ...@@ -82,7 +82,6 @@ diable_wingpu_test="^test_gradient_clip$|\
^test_rnn_op$|\ ^test_rnn_op$|\
^test_simple_rnn_op$|\ ^test_simple_rnn_op$|\
^test_lstm_cudnn_op$|\ ^test_lstm_cudnn_op$|\
^test_conv2d_int8_mkldnn_op$|\
^test_crypto$|\ ^test_crypto$|\
^test_program_prune_backward$|\ ^test_program_prune_backward$|\
^test_imperative_ocr_attention_model$|\ ^test_imperative_ocr_attention_model$|\
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册