未验证 提交 fa2f67a5 编写于 作者: P Paulina Gacek 提交者: GitHub

Tests for other dtypes corrected (#46836)

上级 78add057
...@@ -63,7 +63,7 @@ class TestNearestInterpMKLDNNOp(OpTest): ...@@ -63,7 +63,7 @@ class TestNearestInterpMKLDNNOp(OpTest):
pass pass
def init_data_type(self): def init_data_type(self):
pass self.dtype = np.float32
def setUp(self): def setUp(self):
self.op_type = "nearest_interp" self.op_type = "nearest_interp"
...@@ -72,7 +72,6 @@ class TestNearestInterpMKLDNNOp(OpTest): ...@@ -72,7 +72,6 @@ class TestNearestInterpMKLDNNOp(OpTest):
self.use_mkldnn = True self.use_mkldnn = True
self.input_shape = [1, 1, 2, 2] self.input_shape = [1, 1, 2, 2]
self.data_layout = 'NCHW' self.data_layout = 'NCHW'
self.dtype = np.float32
# priority: actual_shape > out_size > scale > out_h & out_w # priority: actual_shape > out_size > scale > out_h & out_w
self.out_h = 1 self.out_h = 1
self.out_w = 1 self.out_w = 1
...@@ -176,11 +175,9 @@ class TestNearestNeighborInterpSame(TestNearestInterpMKLDNNOp): ...@@ -176,11 +175,9 @@ class TestNearestNeighborInterpSame(TestNearestInterpMKLDNNOp):
def create_test_class(parent): def create_test_class(parent):
'''
class TestFp32Case(parent): Create tests for int, uint8. By default parent class works on fp32.
'''
def init_data_type(self):
self.dtype = np.float32
class TestInt8Case(parent): class TestInt8Case(parent):
...@@ -192,12 +189,10 @@ def create_test_class(parent): ...@@ -192,12 +189,10 @@ def create_test_class(parent):
def init_data_type(self): def init_data_type(self):
self.dtype = np.uint8 self.dtype = np.uint8
TestFp32Case.__name__ = parent.__name__ TestInt8Case.__name__ = "{0}_{1}".format(parent.__name__, "INT8")
TestInt8Case.__name__ = parent.__name__ TestUint8Case.__name__ = "{0}_{1}".format(parent.__name__, "UINT8")
TestUint8Case.__name__ = parent.__name__ globals()[TestInt8Case.__name__] = TestInt8Case
globals()[parent.__name__] = TestFp32Case globals()[TestUint8Case.__name__] = TestUint8Case
globals()[parent.__name__] = TestInt8Case
globals()[parent.__name__] = TestUint8Case
create_test_class(TestNearestInterpMKLDNNOp) create_test_class(TestNearestInterpMKLDNNOp)
...@@ -205,7 +200,6 @@ create_test_class(TestNearestInterpOpMKLDNNNHWC) ...@@ -205,7 +200,6 @@ create_test_class(TestNearestInterpOpMKLDNNNHWC)
create_test_class(TestNearestNeighborInterpMKLDNNCase2) create_test_class(TestNearestNeighborInterpMKLDNNCase2)
create_test_class(TestNearestNeighborInterpCase3) create_test_class(TestNearestNeighborInterpCase3)
create_test_class(TestNearestNeighborInterpCase4) create_test_class(TestNearestNeighborInterpCase4)
create_test_class(TestNearestInterpOpMKLDNNNHWC)
create_test_class(TestNearestNeighborInterpSame) create_test_class(TestNearestNeighborInterpSame)
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -64,7 +64,7 @@ class TestNearestInterpV2MKLDNNOp(OpTest): ...@@ -64,7 +64,7 @@ class TestNearestInterpV2MKLDNNOp(OpTest):
pass pass
def init_data_type(self): def init_data_type(self):
pass self.dtype = np.float32
def setUp(self): def setUp(self):
self.op_type = "nearest_interp_v2" self.op_type = "nearest_interp_v2"
...@@ -73,7 +73,6 @@ class TestNearestInterpV2MKLDNNOp(OpTest): ...@@ -73,7 +73,6 @@ class TestNearestInterpV2MKLDNNOp(OpTest):
self.use_mkldnn = True self.use_mkldnn = True
self.input_shape = [1, 1, 2, 2] self.input_shape = [1, 1, 2, 2]
self.data_layout = 'NCHW' self.data_layout = 'NCHW'
self.dtype = np.float32
# priority: actual_shape > out_size > scale > out_h & out_w # priority: actual_shape > out_size > scale > out_h & out_w
self.out_h = 1 self.out_h = 1
self.out_w = 1 self.out_w = 1
...@@ -196,11 +195,9 @@ class TestNearestNeighborInterpV2MKLDNNSame(TestNearestInterpV2MKLDNNOp): ...@@ -196,11 +195,9 @@ class TestNearestNeighborInterpV2MKLDNNSame(TestNearestInterpV2MKLDNNOp):
def create_test_class(parent): def create_test_class(parent):
'''
class TestFp32Case(parent): Create tests for bf16, int, uint8. By default parent class works on fp32.
'''
def init_data_type(self):
self.dtype = np.float32
class TestBf16Case(parent): class TestBf16Case(parent):
...@@ -217,11 +214,9 @@ def create_test_class(parent): ...@@ -217,11 +214,9 @@ def create_test_class(parent):
def init_data_type(self): def init_data_type(self):
self.dtype = np.uint8 self.dtype = np.uint8
TestFp32Case.__name__ = "{0}_{1}".format(parent.__name__, "FP32")
TestBf16Case.__name__ = "{0}_{1}".format(parent.__name__, "BF16") TestBf16Case.__name__ = "{0}_{1}".format(parent.__name__, "BF16")
TestInt8Case.__name__ = "{0}_{1}".format(parent.__name__, "INT8") TestInt8Case.__name__ = "{0}_{1}".format(parent.__name__, "INT8")
TestUint8Case.__name__ = "{0}_{1}".format(parent.__name__, "UINT8") TestUint8Case.__name__ = "{0}_{1}".format(parent.__name__, "UINT8")
globals()[TestFp32Case.__name__] = TestFp32Case
globals()[TestBf16Case.__name__] = TestBf16Case globals()[TestBf16Case.__name__] = TestBf16Case
globals()[TestInt8Case.__name__] = TestInt8Case globals()[TestInt8Case.__name__] = TestInt8Case
globals()[TestUint8Case.__name__] = TestUint8Case globals()[TestUint8Case.__name__] = TestUint8Case
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册