未验证 提交 fa2f67a5 编写于 作者: P Paulina Gacek 提交者: GitHub

Tests for other dtypes corrected (#46836)

上级 78add057
......@@ -63,7 +63,7 @@ class TestNearestInterpMKLDNNOp(OpTest):
pass
def init_data_type(self):
pass
self.dtype = np.float32
def setUp(self):
self.op_type = "nearest_interp"
......@@ -72,7 +72,6 @@ class TestNearestInterpMKLDNNOp(OpTest):
self.use_mkldnn = True
self.input_shape = [1, 1, 2, 2]
self.data_layout = 'NCHW'
self.dtype = np.float32
# priority: actual_shape > out_size > scale > out_h & out_w
self.out_h = 1
self.out_w = 1
......@@ -176,11 +175,9 @@ class TestNearestNeighborInterpSame(TestNearestInterpMKLDNNOp):
def create_test_class(parent):
class TestFp32Case(parent):
def init_data_type(self):
self.dtype = np.float32
'''
Create tests for int, uint8. By default parent class works on fp32.
'''
class TestInt8Case(parent):
......@@ -192,12 +189,10 @@ def create_test_class(parent):
def init_data_type(self):
self.dtype = np.uint8
TestFp32Case.__name__ = parent.__name__
TestInt8Case.__name__ = parent.__name__
TestUint8Case.__name__ = parent.__name__
globals()[parent.__name__] = TestFp32Case
globals()[parent.__name__] = TestInt8Case
globals()[parent.__name__] = TestUint8Case
TestInt8Case.__name__ = "{0}_{1}".format(parent.__name__, "INT8")
TestUint8Case.__name__ = "{0}_{1}".format(parent.__name__, "UINT8")
globals()[TestInt8Case.__name__] = TestInt8Case
globals()[TestUint8Case.__name__] = TestUint8Case
create_test_class(TestNearestInterpMKLDNNOp)
......@@ -205,7 +200,6 @@ create_test_class(TestNearestInterpOpMKLDNNNHWC)
create_test_class(TestNearestNeighborInterpMKLDNNCase2)
create_test_class(TestNearestNeighborInterpCase3)
create_test_class(TestNearestNeighborInterpCase4)
create_test_class(TestNearestInterpOpMKLDNNNHWC)
create_test_class(TestNearestNeighborInterpSame)
if __name__ == "__main__":
......
......@@ -64,7 +64,7 @@ class TestNearestInterpV2MKLDNNOp(OpTest):
pass
def init_data_type(self):
pass
self.dtype = np.float32
def setUp(self):
self.op_type = "nearest_interp_v2"
......@@ -73,7 +73,6 @@ class TestNearestInterpV2MKLDNNOp(OpTest):
self.use_mkldnn = True
self.input_shape = [1, 1, 2, 2]
self.data_layout = 'NCHW'
self.dtype = np.float32
# priority: actual_shape > out_size > scale > out_h & out_w
self.out_h = 1
self.out_w = 1
......@@ -196,11 +195,9 @@ class TestNearestNeighborInterpV2MKLDNNSame(TestNearestInterpV2MKLDNNOp):
def create_test_class(parent):
class TestFp32Case(parent):
def init_data_type(self):
self.dtype = np.float32
'''
Create tests for bf16, int, uint8. By default parent class works on fp32.
'''
class TestBf16Case(parent):
......@@ -217,11 +214,9 @@ def create_test_class(parent):
def init_data_type(self):
self.dtype = np.uint8
TestFp32Case.__name__ = "{0}_{1}".format(parent.__name__, "FP32")
TestBf16Case.__name__ = "{0}_{1}".format(parent.__name__, "BF16")
TestInt8Case.__name__ = "{0}_{1}".format(parent.__name__, "INT8")
TestUint8Case.__name__ = "{0}_{1}".format(parent.__name__, "UINT8")
globals()[TestFp32Case.__name__] = TestFp32Case
globals()[TestBf16Case.__name__] = TestBf16Case
globals()[TestInt8Case.__name__] = TestInt8Case
globals()[TestUint8Case.__name__] = TestUint8Case
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册