未验证 提交 a2382986 编写于 作者: L lidanqing 提交者: GitHub

Skip some conv2d_int8 tests in windows (#30128)

上级 a60893f6
......@@ -16,7 +16,7 @@ from __future__ import print_function
import unittest
import numpy as np
import os
import paddle.fluid.core as core
from paddle.fluid.tests.unittests.op_test import OpTest
from paddle.fluid.tests.unittests.test_conv2d_op import conv2d_forward_naive, TestConv2DOp
......@@ -28,6 +28,8 @@ def conv2d_forward_refer(input, filter, group, conv_param):
return out
@unittest.skipIf(not core.supports_bfloat16(),
"place does not support BF16 evaluation")
class TestConv2DInt8Op(TestConv2DOp):
def setUp(self):
self.op_type = "conv2d"
......@@ -289,43 +291,31 @@ def init_data_type_with_fusion(self, input_dt, fuse_activation, fuse_residual):
def create_test_int8_class(parent):
#--------------------test conv2d s8 in and u8 out--------------------
class TestS8U8Case(parent):
def init_data_type(self):
init_data_type_with_fusion(self, np.int8, "relu", False)
#--------------------test conv2d s8 in and s8 out--------------------
class TestS8S8Case(parent):
def init_data_type(self):
init_data_type_with_fusion(self, np.int8, "", False)
#--------------------test conv2d u8 in and s8 out--------------------
class TestU8S8Case(parent):
def init_data_type(self):
init_data_type_with_fusion(self, np.uint8, "", False)
#--------------------test conv2d u8 in and u8 out without residual fuse--------------------
class TestU8U8Case(parent):
def init_data_type(self):
init_data_type_with_fusion(self, np.uint8, "relu", False)
#--------------------test conv2d s8 in and u8 out with residual fuse--------------------
class TestS8U8ResCase(parent):
def init_data_type(self):
init_data_type_with_fusion(self, np.int8, "relu", True)
#--------------------test conv2d s8 in and s8 out with residual fuse--------------------
class TestS8S8ResCase(parent):
def init_data_type(self):
init_data_type_with_fusion(self, np.int8, "", True)
#--------------------test conv2d u8 in and s8 out with residual fuse--------------------
class TestU8S8ResCase(parent):
def init_data_type(self):
init_data_type_with_fusion(self, np.uint8, "", True)
......@@ -334,8 +324,7 @@ def create_test_int8_class(parent):
cls_name_s8s8 = "{0}_relu_{1}_residual_0".format(parent.__name__, "0")
cls_name_u8s8 = "{0}_relu_{1}_residual_0".format(parent.__name__, "0")
cls_name_u8u8 = "{0}_relu_{1}_residual_0".format(parent.__name__, "1")
cls_name_s8u8_re_1 = "{0}_relu_{1}_residual_{2}".format(parent.__name__,
"1", "1")
cls_name_s8s8_re_1 = "{0}_relu_{1}_residual_{2}".format(parent.__name__,
"0", "1")
cls_name_u8s8_re_1 = "{0}_relu_{1}_residual_{2}".format(parent.__name__,
......@@ -344,17 +333,27 @@ def create_test_int8_class(parent):
TestS8S8Case.__name__ = cls_name_s8s8
TestU8S8Case.__name__ = cls_name_u8s8
TestU8U8Case.__name__ = cls_name_u8u8
TestS8U8ResCase.__name__ = cls_name_s8u8_re_1
TestS8S8ResCase.__name__ = cls_name_s8s8_re_1
TestU8S8ResCase.__name__ = cls_name_u8s8_re_1
globals()[cls_name_s8u8] = TestS8U8Case
globals()[cls_name_s8s8] = TestS8S8Case
globals()[cls_name_u8s8] = TestU8S8Case
globals()[cls_name_u8u8] = TestU8U8Case
globals()[cls_name_s8u8_re_1] = TestS8U8ResCase
globals()[cls_name_s8s8_re_1] = TestS8S8ResCase
globals()[cls_name_u8s8_re_1] = TestU8S8ResCase
if os.name != 'nt':
#--------------------test conv2d s8 in and u8 out with residual fuse--------------------
class TestS8U8ResCase(parent):
def init_data_type(self):
init_data_type_with_fusion(self, np.int8, "relu", True)
cls_name_s8u8_re_1 = "{0}_relu_{1}_residual_{2}".format(parent.__name__,
"1", "1")
TestS8U8ResCase.__name__ = cls_name_s8u8_re_1
globals()[cls_name_s8u8_re_1] = TestS8U8ResCase
create_test_int8_class(TestConv2DInt8Op)
create_test_int8_class(TestWithPad)
......@@ -387,4 +386,6 @@ class TestConv2DOp_Valid_INT_MKLDNN(TestConv2DOp_AsyPadding_INT_MKLDNN):
if __name__ == '__main__':
from paddle import enable_static
enable_static()
unittest.main()
......@@ -82,7 +82,6 @@ diable_wingpu_test="^test_gradient_clip$|\
^test_rnn_op$|\
^test_simple_rnn_op$|\
^test_lstm_cudnn_op$|\
^test_conv2d_int8_mkldnn_op$|\
^test_crypto$|\
^test_program_prune_backward$|\
^test_imperative_ocr_attention_model$|\
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册